mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-11-04 23:10:35 +01:00
Add "auth-tokens"
This commit is contained in:
parent
149c13e9d8
commit
23ec7702fc
10 changed files with 263 additions and 88 deletions
56
cmd/serve.go
56
cmd/serve.go
|
|
@ -50,6 +50,7 @@ var flagsServe = append(
|
|||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-users", Aliases: []string{"auth_users"}, EnvVars: []string{"NTFY_AUTH_USERS"}, Usage: "pre-provisioned declarative users"}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}),
|
||||
|
|
@ -158,6 +159,7 @@ func execServe(c *cli.Context) error {
|
|||
authDefaultAccess := c.String("auth-default-access")
|
||||
authUsersRaw := c.StringSlice("auth-users")
|
||||
authAccessRaw := c.StringSlice("auth-access")
|
||||
authTokensRaw := c.StringSlice("auth-tokens")
|
||||
attachmentCacheDir := c.String("attachment-cache-dir")
|
||||
attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit")
|
||||
attachmentFileSizeLimitStr := c.String("attachment-file-size-limit")
|
||||
|
|
@ -361,6 +363,10 @@ func execServe(c *cli.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
authTokens, err := parseTokens(authUsers, authTokensRaw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Special case: Unset default
|
||||
if listenHTTP == "-" {
|
||||
|
|
@ -418,6 +424,7 @@ func execServe(c *cli.Context) error {
|
|||
conf.AuthDefault = authDefault
|
||||
conf.AuthUsers = authUsers
|
||||
conf.AuthAccess = authAccess
|
||||
conf.AuthTokens = authTokens
|
||||
conf.AttachmentCacheDir = attachmentCacheDir
|
||||
conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit
|
||||
conf.AttachmentFileSizeLimit = attachmentFileSizeLimit
|
||||
|
|
@ -532,7 +539,7 @@ func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
|
|||
}
|
||||
|
||||
func parseUsers(usersRaw []string) ([]*user.User, error) {
|
||||
provisionUsers := make([]*user.User, 0)
|
||||
users := make([]*user.User, 0)
|
||||
for _, userLine := range usersRaw {
|
||||
parts := strings.Split(userLine, ":")
|
||||
if len(parts) != 3 {
|
||||
|
|
@ -548,19 +555,19 @@ func parseUsers(usersRaw []string) ([]*user.User, error) {
|
|||
} else if !user.AllowedRole(role) {
|
||||
return nil, fmt.Errorf("invalid auth-users: %s, role %s is not allowed, allowed roles are 'admin' or 'user'", userLine, role)
|
||||
}
|
||||
provisionUsers = append(provisionUsers, &user.User{
|
||||
users = append(users, &user.User{
|
||||
Name: username,
|
||||
Hash: passwordHash,
|
||||
Role: role,
|
||||
Provisioned: true,
|
||||
})
|
||||
}
|
||||
return provisionUsers, nil
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[string][]*user.Grant, error) {
|
||||
func parseAccess(users []*user.User, accessRaw []string) (map[string][]*user.Grant, error) {
|
||||
access := make(map[string][]*user.Grant)
|
||||
for _, accessLine := range provisionAccessRaw {
|
||||
for _, accessLine := range accessRaw {
|
||||
parts := strings.Split(accessLine, ":")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid auth-access: %s, expected format: 'user:topic:permission'", accessLine)
|
||||
|
|
@ -569,7 +576,7 @@ func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[
|
|||
if username == userEveryone {
|
||||
username = user.Everyone
|
||||
}
|
||||
provisionUser, exists := util.Find(provisionUsers, func(u *user.User) bool {
|
||||
u, exists := util.Find(users, func(u *user.User) bool {
|
||||
return u.Name == username
|
||||
})
|
||||
if username != user.Everyone {
|
||||
|
|
@ -577,7 +584,7 @@ func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[
|
|||
return nil, fmt.Errorf("invalid auth-access: %s, user %s is not provisioned", accessLine, username)
|
||||
} else if !user.AllowedUsername(username) {
|
||||
return nil, fmt.Errorf("invalid auth-access: %s, username %s invalid", accessLine, username)
|
||||
} else if provisionUser.Role != user.RoleUser {
|
||||
} else if u.Role != user.RoleUser {
|
||||
return nil, fmt.Errorf("invalid auth-access: %s, user %s is not a regular user, only regular users can have ACL entries", accessLine, username)
|
||||
}
|
||||
}
|
||||
|
|
@ -601,6 +608,41 @@ func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[
|
|||
return access, nil
|
||||
}
|
||||
|
||||
func parseTokens(users []*user.User, tokensRaw []string) (map[string][]*user.Token, error) {
|
||||
tokens := make(map[string][]*user.Token)
|
||||
for _, tokenLine := range tokensRaw {
|
||||
parts := strings.Split(tokenLine, ":")
|
||||
if len(parts) < 2 || len(parts) > 3 {
|
||||
return nil, fmt.Errorf("invalid auth-tokens: %s, expected format: 'user:token[:label]'", tokenLine)
|
||||
}
|
||||
username := strings.TrimSpace(parts[0])
|
||||
_, exists := util.Find(users, func(u *user.User) bool {
|
||||
return u.Name == username
|
||||
})
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("invalid auth-tokens: %s, user %s is not provisioned", tokenLine, username)
|
||||
} else if !user.AllowedUsername(username) {
|
||||
return nil, fmt.Errorf("invalid auth-tokens: %s, username %s invalid", tokenLine, username)
|
||||
}
|
||||
token := strings.TrimSpace(parts[1])
|
||||
if !user.AllowedToken(token) {
|
||||
return nil, fmt.Errorf("invalid auth-tokens: %s, token %s invalid, use 'ntfy token generate' to generate a random token", tokenLine, token)
|
||||
}
|
||||
var label string
|
||||
if len(parts) > 2 {
|
||||
label = parts[2]
|
||||
}
|
||||
if _, exists := tokens[username]; !exists {
|
||||
tokens[username] = make([]*user.Token, 0)
|
||||
}
|
||||
tokens[username] = append(tokens[username], &user.Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
})
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func reloadLogLevel(inputSource altsrc.InputSourceContext) error {
|
||||
newLevelStr, err := inputSource.String("log-level")
|
||||
if err != nil {
|
||||
|
|
|
|||
29
cmd/token.go
29
cmd/token.go
|
|
@ -72,6 +72,15 @@ Example:
|
|||
This is a server-only command. It directly reads from user.db as defined in the server config
|
||||
file server.yml. The command only works if 'auth-file' is properly defined.`,
|
||||
},
|
||||
{
|
||||
Name: "generate",
|
||||
Usage: "Generates a random token",
|
||||
Action: execTokenGenerate,
|
||||
Description: `Randomly generate a token to be used in provisioned tokens.
|
||||
|
||||
This command only generates the token value, but does not persist it anywhere.
|
||||
The output can be used in the 'auth-tokens' config option.`,
|
||||
},
|
||||
},
|
||||
Description: `Manage access tokens for individual users.
|
||||
|
||||
|
|
@ -112,12 +121,12 @@ func execTokenAdd(c *cli.Context) error {
|
|||
return err
|
||||
}
|
||||
u, err := manager.User(username)
|
||||
if err == user.ErrUserNotFound {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
token, err := manager.CreateToken(u.ID, label, expires, netip.IPv4Unspecified())
|
||||
token, err := manager.CreateToken(u.ID, label, expires, netip.IPv4Unspecified(), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -141,7 +150,7 @@ func execTokenDel(c *cli.Context) error {
|
|||
return err
|
||||
}
|
||||
u, err := manager.User(username)
|
||||
if err == user.ErrUserNotFound {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
} else if err != nil {
|
||||
return err
|
||||
|
|
@ -165,7 +174,7 @@ func execTokenList(c *cli.Context) error {
|
|||
var users []*user.User
|
||||
if username != "" {
|
||||
u, err := manager.User(username)
|
||||
if err == user.ErrUserNotFound {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
return fmt.Errorf("user %s does not exist", username)
|
||||
} else if err != nil {
|
||||
return err
|
||||
|
|
@ -191,7 +200,7 @@ func execTokenList(c *cli.Context) error {
|
|||
usersWithTokens++
|
||||
fmt.Fprintf(c.App.ErrWriter, "user %s\n", u.Name)
|
||||
for _, t := range tokens {
|
||||
var label, expires string
|
||||
var label, expires, provisioned string
|
||||
if t.Label != "" {
|
||||
label = fmt.Sprintf(" (%s)", t.Label)
|
||||
}
|
||||
|
|
@ -200,7 +209,10 @@ func execTokenList(c *cli.Context) error {
|
|||
} else {
|
||||
expires = fmt.Sprintf("expires %s", t.Expires.Format(time.RFC822))
|
||||
}
|
||||
fmt.Fprintf(c.App.ErrWriter, "- %s%s, %s, accessed from %s at %s\n", t.Value, label, expires, t.LastOrigin.String(), t.LastAccess.Format(time.RFC822))
|
||||
if t.Provisioned {
|
||||
provisioned = " (server config)"
|
||||
}
|
||||
fmt.Fprintf(c.App.ErrWriter, "- %s%s, %s, accessed from %s at %s%s\n", t.Value, label, expires, t.LastOrigin.String(), t.LastAccess.Format(time.RFC822), provisioned)
|
||||
}
|
||||
}
|
||||
if usersWithTokens == 0 {
|
||||
|
|
@ -208,3 +220,8 @@ func execTokenList(c *cli.Context) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func execTokenGenerate(c *cli.Context) error {
|
||||
fmt.Println(user.GenerateToken())
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ type Config struct {
|
|||
AuthDefault user.Permission
|
||||
AuthUsers []*user.User
|
||||
AuthAccess map[string][]*user.Grant
|
||||
AuthTokens map[string][]*user.Token
|
||||
AuthBcryptCost int
|
||||
AuthStatsQueueWriterInterval time.Duration
|
||||
AttachmentCacheDir string
|
||||
|
|
|
|||
|
|
@ -203,6 +203,7 @@ func New(conf *Config) (*Server, error) {
|
|||
ProvisionEnabled: true, // Enable provisioning of users and access
|
||||
Users: conf.AuthUsers,
|
||||
Access: conf.AuthAccess,
|
||||
Tokens: conf.AuthTokens,
|
||||
BcryptCost: conf.AuthBcryptCost,
|
||||
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,6 +86,8 @@
|
|||
# Each entry is in the format "<username>:<password-hash>:<role>", e.g. "phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:user"
|
||||
# - auth-access is a list of access control entries that are automatically created when the server starts.
|
||||
# Each entry is in the format "<username>:<topic-pattern>:<access>", e.g. "phil:mytopic:rw" or "phil:phil-*:rw".
|
||||
# - auth-tokens is a list of access tokens that are automatically created when the server starts.
|
||||
# Each entry is in the format "<username>:<token>[:<label>]", e.g. "phil:tk_1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef:My token".
|
||||
#
|
||||
# Debian/RPM package users:
|
||||
# Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package
|
||||
|
|
|
|||
|
|
@ -234,7 +234,7 @@ func (s *Server) handleAccountTokenCreate(w http.ResponseWriter, r *http.Request
|
|||
"token_expires": expires,
|
||||
}).
|
||||
Debug("Creating token for user %s", u.Name)
|
||||
token, err := s.userManager.CreateToken(u.ID, label, expires, v.IP())
|
||||
token, err := s.userManager.CreateToken(u.ID, label, expires, v.IP(), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
|
|||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
u, _ := s.userManager.User("phil")
|
||||
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified())
|
||||
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified(), false)
|
||||
|
||||
rr := request(t, s, "PATCH", "/v1/account/settings", `{"notification": {"sound": "juntos"},"ignored": true}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
|
|
|||
163
user/manager.go
163
user/manager.go
|
|
@ -111,9 +111,11 @@ const (
|
|||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
|
|
@ -181,16 +183,17 @@ const (
|
|||
ELSE 2
|
||||
END, user
|
||||
`
|
||||
selectUserCountQuery = `SELECT COUNT(*) FROM user`
|
||||
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
|
||||
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
|
||||
updateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
|
||||
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
|
||||
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
||||
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
||||
deleteUserQuery = `DELETE FROM user WHERE user = ?`
|
||||
selectUserCountQuery = `SELECT COUNT(*) FROM user`
|
||||
selectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
|
||||
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
|
||||
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
|
||||
updateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
|
||||
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
|
||||
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
||||
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
||||
deleteUserQuery = `DELETE FROM user WHERE user = ?`
|
||||
|
||||
upsertUserAccessQuery = `
|
||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||
|
|
@ -255,17 +258,23 @@ const (
|
|||
AND topic = ?
|
||||
`
|
||||
|
||||
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
|
||||
selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
|
||||
insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
|
||||
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||
deleteExcessTokensQuery = `
|
||||
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
selectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
|
||||
selectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
|
||||
upsertTokenQuery = `
|
||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (user_id, token)
|
||||
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
|
||||
`
|
||||
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||
deleteTokensProvisionedQuery = `DELETE FROM user_token WHERE provisioned = 1`
|
||||
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||
deleteExcessTokensQuery = `
|
||||
DELETE FROM user_token
|
||||
WHERE user_id = ?
|
||||
AND (user_id, token) NOT IN (
|
||||
|
|
@ -470,7 +479,7 @@ const (
|
|||
role,
|
||||
prefs,
|
||||
sync_topic,
|
||||
0,
|
||||
0, -- provisioned
|
||||
stats_messages,
|
||||
stats_emails,
|
||||
stats_calls,
|
||||
|
|
@ -480,7 +489,8 @@ const (
|
|||
stripe_subscription_interval,
|
||||
stripe_subscription_paid_until,
|
||||
stripe_subscription_cancel_at,
|
||||
created, deleted
|
||||
created,
|
||||
deleted
|
||||
FROM user_old;
|
||||
DROP TABLE user_old;
|
||||
|
||||
|
|
@ -500,10 +510,27 @@ const (
|
|||
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
|
||||
DROP TABLE user_access_old;
|
||||
|
||||
-- Alter user_token table: Add provisioned column
|
||||
ALTER TABLE user_token RENAME TO user_token_old;
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
|
||||
DROP TABLE user_token_old;
|
||||
|
||||
-- Recreate indices
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
|
||||
-- Re-enable foreign keys
|
||||
PRAGMA foreign_keys=on;
|
||||
|
|
@ -537,7 +564,8 @@ type Config struct {
|
|||
DefaultAccess Permission // Default permission if no ACL matches
|
||||
ProvisionEnabled bool // Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
|
||||
Users []*User // Predefined users to create on startup
|
||||
Access map[string][]*Grant // Predefined access grants to create on startup
|
||||
Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
|
||||
Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
|
||||
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
|
||||
BcryptCost int // Cost of generated passwords; lowering makes testing faster
|
||||
}
|
||||
|
|
@ -623,15 +651,15 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
|
|||
// CreateToken generates a random token for the given user and returns it. The token expires
|
||||
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
||||
// given user, if there are too many of them.
|
||||
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
|
||||
token := util.RandomLowerStringPrefix(tokenPrefix, tokenLength) // Lowercase only to support "<topic>+<token>@<domain>" email addresses
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
||||
return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
|
||||
return a.createTokenTx(tx, userID, GenerateToken(), label, expires, origin, provisioned)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
||||
access := time.Now()
|
||||
if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
|
||||
if _, err := tx.Exec(upsertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix(), provisioned); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := tx.Query(selectTokenCountQuery, userID)
|
||||
|
|
@ -653,15 +681,13 @@ func (a *Manager) CreateToken(userID, label string, expires time.Time, origin ne
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: access,
|
||||
LastOrigin: origin,
|
||||
Expires: expires,
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: access,
|
||||
LastOrigin: origin,
|
||||
Expires: expires,
|
||||
Provisioned: provisioned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -698,10 +724,11 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
|
|||
func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
|
||||
var token, label, lastOrigin string
|
||||
var lastAccess, expires int64
|
||||
var provisioned bool
|
||||
if !rows.Next() {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
|
||||
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
|
|
@ -711,11 +738,12 @@ func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
|
|||
lastOriginIP = netip.IPv4Unspecified()
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: time.Unix(lastAccess, 0),
|
||||
LastOrigin: lastOriginIP,
|
||||
Expires: time.Unix(expires, 0),
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: time.Unix(lastAccess, 0),
|
||||
LastOrigin: lastOriginIP,
|
||||
Expires: time.Unix(expires, 0),
|
||||
Provisioned: provisioned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -774,7 +802,7 @@ func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
|
|||
phoneNumbers := make([]string, 0)
|
||||
for {
|
||||
phoneNumber, err := a.readPhoneNumber(rows)
|
||||
if err == ErrPhoneNumberNotFound {
|
||||
if errors.Is(err, ErrPhoneNumberNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -1757,6 +1785,28 @@ func (a *Manager) maybeProvisionUsersAndAccess() error {
|
|||
}
|
||||
}
|
||||
}
|
||||
// Remove and (re-)add provisioned tokens
|
||||
if _, err := tx.Exec(deleteTokensProvisionedQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
for username, tokens := range a.config.Tokens {
|
||||
_, exists := util.Find(a.config.Users, func(u *User) bool {
|
||||
return u.Name == username
|
||||
})
|
||||
if !exists && username != Everyone {
|
||||
return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
|
||||
}
|
||||
var userID string
|
||||
row := tx.QueryRow(selectUserIDFromUsernameQuery, username)
|
||||
if err := row.Scan(&userID); err != nil {
|
||||
return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username)
|
||||
}
|
||||
for _, token := range tokens {
|
||||
if _, err = a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
@ -1974,3 +2024,22 @@ func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
|||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// queryTx executes a function in a transaction and returns the result. If the function
|
||||
// returns an error, the transaction is rolled back.
|
||||
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
t, err := f(tx)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return t, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
require.False(t, u.Deleted)
|
||||
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
|
||||
require.Nil(t, err)
|
||||
|
||||
u, err = a.Authenticate("user", "pass")
|
||||
|
|
@ -241,7 +241,7 @@ func TestManager_CreateToken_Only_Lower(t *testing.T) {
|
|||
u, err := a.User("user")
|
||||
require.Nil(t, err)
|
||||
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, token.Value, strings.ToLower(token.Value))
|
||||
}
|
||||
|
|
@ -523,7 +523,7 @@ func TestManager_Token_Valid(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
|
||||
// Create token for user
|
||||
token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token.Value)
|
||||
require.Equal(t, "some label", token.Label)
|
||||
|
|
@ -586,12 +586,12 @@ func TestManager_Token_Expire(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
|
||||
// Create tokens for user
|
||||
token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token1.Value)
|
||||
require.True(t, time.Now().Add(71*time.Hour).Unix() < token1.Expires.Unix())
|
||||
|
||||
token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token2.Value)
|
||||
require.NotEqual(t, token1.Value, token2.Value)
|
||||
|
|
@ -638,7 +638,7 @@ func TestManager_Token_Extend(t *testing.T) {
|
|||
require.Equal(t, errNoTokenProvided, err)
|
||||
|
||||
// Create token for user
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token.Value)
|
||||
|
||||
|
|
@ -668,12 +668,12 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
|||
|
||||
// Create 2 tokens for phil
|
||||
philTokens := make([]string, 0)
|
||||
token, err := a.CreateToken(phil.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
token, err := a.CreateToken(phil.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token.Value)
|
||||
philTokens = append(philTokens, token.Value)
|
||||
|
||||
token, err = a.CreateToken(phil.ID, "", time.Unix(0, 0), netip.IPv4Unspecified())
|
||||
token, err = a.CreateToken(phil.ID, "", time.Unix(0, 0), netip.IPv4Unspecified(), false)
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token.Value)
|
||||
philTokens = append(philTokens, token.Value)
|
||||
|
|
@ -682,7 +682,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
|||
baseTime := time.Now().Add(24 * time.Hour)
|
||||
benTokens := make([]string, 0)
|
||||
for i := 0; i < 62; i++ { //
|
||||
token, err := a.CreateToken(ben.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||
token, err := a.CreateToken(ben.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, token.Value)
|
||||
benTokens = append(benTokens, token.Value)
|
||||
|
|
@ -795,7 +795,7 @@ func TestManager_EnqueueTokenUpdate(t *testing.T) {
|
|||
u, err := a.User("ben")
|
||||
require.Nil(t, err)
|
||||
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
|
||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Queue token update
|
||||
|
|
@ -1112,6 +1112,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
|||
{TopicPattern: "secret", Permission: PermissionRead},
|
||||
},
|
||||
},
|
||||
Tokens: map[string][]*Token{
|
||||
"philuser": {
|
||||
{Value: "tk_op56p8lz5bf3cxkz9je99v9oc37lo", Label: "Alerts token"},
|
||||
},
|
||||
},
|
||||
}
|
||||
a, err := NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
|
|
@ -1123,24 +1128,29 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
|||
users, err := a.Users()
|
||||
require.Nil(t, err)
|
||||
require.Len(t, users, 4)
|
||||
|
||||
require.Equal(t, "philadmin", users[0].Name)
|
||||
require.Equal(t, RoleAdmin, users[0].Role)
|
||||
|
||||
require.Equal(t, "philmanual", users[1].Name)
|
||||
require.Equal(t, RoleUser, users[1].Role)
|
||||
require.Equal(t, "philuser", users[2].Name)
|
||||
require.Equal(t, RoleUser, users[2].Role)
|
||||
require.Equal(t, "*", users[3].Name)
|
||||
provisionedUserID := users[2].ID // "philuser" is the provisioned user
|
||||
|
||||
grants, err := a.Grants("philuser")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "philuser", users[2].Name)
|
||||
require.Equal(t, RoleUser, users[2].Role)
|
||||
require.Equal(t, 2, len(grants))
|
||||
require.Equal(t, "secret", grants[0].TopicPattern)
|
||||
require.Equal(t, PermissionRead, grants[0].Permission)
|
||||
require.Equal(t, "stats", grants[1].TopicPattern)
|
||||
require.Equal(t, PermissionReadWrite, grants[1].Permission)
|
||||
|
||||
require.Equal(t, "*", users[3].Name)
|
||||
tokens, err := a.Tokens(provisionedUserID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(tokens))
|
||||
require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc37lo", tokens[0].Value)
|
||||
require.Equal(t, "Alerts token", tokens[0].Label)
|
||||
require.True(t, tokens[0].Provisioned)
|
||||
|
||||
// Re-open the DB (second app start)
|
||||
require.Nil(t, a.db.Close())
|
||||
|
|
@ -1153,6 +1163,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
|||
{TopicPattern: "secret12", Permission: PermissionRead},
|
||||
},
|
||||
}
|
||||
conf.Tokens = map[string][]*Token{
|
||||
"philuser": {
|
||||
{Value: "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", Label: "Alerts token updated"},
|
||||
},
|
||||
}
|
||||
a, err = NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
|
||||
|
|
@ -1160,30 +1175,36 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
|||
users, err = a.Users()
|
||||
require.Nil(t, err)
|
||||
require.Len(t, users, 3)
|
||||
|
||||
require.Equal(t, "philmanual", users[0].Name)
|
||||
require.Equal(t, "philuser", users[1].Name)
|
||||
require.Equal(t, RoleUser, users[1].Role)
|
||||
require.Equal(t, RoleUser, users[0].Role)
|
||||
require.Equal(t, "*", users[2].Name)
|
||||
|
||||
grants, err = a.Grants("philuser")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "philuser", users[1].Name)
|
||||
require.Equal(t, RoleUser, users[1].Role)
|
||||
require.Equal(t, 2, len(grants))
|
||||
require.Equal(t, "secret12", grants[0].TopicPattern)
|
||||
require.Equal(t, PermissionRead, grants[0].Permission)
|
||||
require.Equal(t, "stats12", grants[1].TopicPattern)
|
||||
require.Equal(t, PermissionReadWrite, grants[1].Permission)
|
||||
|
||||
require.Equal(t, "*", users[2].Name)
|
||||
tokens, err = a.Tokens(provisionedUserID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(tokens))
|
||||
require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", tokens[0].Value)
|
||||
require.Equal(t, "Alerts token updated", tokens[0].Label)
|
||||
require.True(t, tokens[0].Provisioned)
|
||||
|
||||
// Re-open the DB again (third app start)
|
||||
require.Nil(t, a.db.Close())
|
||||
conf.Users = []*User{}
|
||||
conf.Access = map[string][]*Grant{}
|
||||
conf.Tokens = map[string][]*Token{}
|
||||
a, err = NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Check that the provisioned users are there
|
||||
// Check that the provisioned users are all gone
|
||||
users, err = a.Users()
|
||||
require.Nil(t, err)
|
||||
require.Len(t, users, 2)
|
||||
|
|
@ -1191,6 +1212,14 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
|||
require.Equal(t, "philmanual", users[0].Name)
|
||||
require.Equal(t, RoleUser, users[0].Role)
|
||||
require.Equal(t, "*", users[1].Name)
|
||||
|
||||
grants, err = a.Grants("philuser")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(grants))
|
||||
|
||||
tokens, err = a.Tokens(provisionedUserID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(tokens))
|
||||
}
|
||||
|
||||
func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
"github.com/stripe/stripe-go/v74"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
|
@ -59,11 +60,12 @@ type Auther interface {
|
|||
|
||||
// Token represents a user token, including expiry date
|
||||
type Token struct {
|
||||
Value string
|
||||
Label string
|
||||
LastAccess time.Time
|
||||
LastOrigin netip.Addr
|
||||
Expires time.Time
|
||||
Value string
|
||||
Label string
|
||||
LastAccess time.Time
|
||||
LastOrigin netip.Addr
|
||||
Expires time.Time
|
||||
Provisioned bool
|
||||
}
|
||||
|
||||
// TokenUpdate holds information about the last access time and origin IP address of a token
|
||||
|
|
@ -247,6 +249,7 @@ var (
|
|||
allowedTopicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No '*'
|
||||
allowedTopicPatternRegex = regexp.MustCompile(`^[-_*A-Za-z0-9]{1,64}$`) // Adds '*' for wildcards!
|
||||
allowedTierRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)
|
||||
allowedTokenRegex = regexp.MustCompile(`^tk_[-_A-Za-z0-9]{29}$`) // Must be tokenLength-len(tokenPrefix)
|
||||
)
|
||||
|
||||
// AllowedRole returns true if the given role can be used for new users
|
||||
|
|
@ -282,6 +285,17 @@ func AllowedPasswordHash(hash string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// AllowedToken returns true if the given token matches the naming convention
|
||||
func AllowedToken(token string) bool {
|
||||
return allowedTokenRegex.MatchString(token)
|
||||
}
|
||||
|
||||
// GenerateToken generates a new token with a prefix and a fixed length
|
||||
// Lowercase only to support "<topic>+<token>@<domain>" email addresses
|
||||
func GenerateToken() string {
|
||||
return util.RandomLowerStringPrefix(tokenPrefix, tokenLength)
|
||||
}
|
||||
|
||||
// Error constants used by the package
|
||||
var (
|
||||
ErrUnauthenticated = errors.New("unauthenticated")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue