1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-12-03 05:40:09 +01:00
This commit is contained in:
binwiederhier 2025-07-21 17:44:00 +02:00
parent 51af114b2e
commit f59df0f40a
13 changed files with 333 additions and 140 deletions

View file

@ -12,6 +12,7 @@ import (
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
"net/netip"
"path/filepath"
"strings"
"sync"
"time"
@ -75,6 +76,7 @@ const (
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
@ -97,6 +99,7 @@ const (
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
@ -121,8 +124,8 @@ const (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
COMMIT;
`
@ -132,26 +135,26 @@ const (
`
selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
`
selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
@ -165,8 +168,8 @@ const (
`
insertUserQuery = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
selectUsernamesQuery = `
SELECT user
@ -189,18 +192,18 @@ const (
deleteUserQuery = `DELETE FROM user WHERE user = ?`
upsertUserAccessQuery = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
`
selectUserAllAccessQuery = `
SELECT user_id, topic, read, write
SELECT user_id, topic, read, write, provisioned
FROM user_access
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
`
selectUserAccessQuery = `
SELECT topic, read, write
SELECT topic, read, write, provisioned
FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
@ -244,7 +247,8 @@ const (
WHERE user_id = (SELECT id FROM user WHERE user = ?)
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
`
deleteTopicAccessQuery = `
deleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = 1`
deleteTopicAccessQuery = `
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
AND topic = ?
@ -427,6 +431,15 @@ const (
migrate4To5UpdateQueries = `
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
`
// 5 -> 6
migrate5To6UpdateQueries = `
ALTER TABLE user ADD COLUMN provisioned INT NOT NULL DEFAULT (0);
ALTER TABLE user ALTER COLUMN provisioned DROP DEFAULT;
ALTER TABLE user_access ADD COLUMN provisioned INT NOT NULL DEFAULT (0);
ALTER TABLE user_access ALTER COLUMN provisioned DROP DEFAULT;
`
)
var (
@ -435,6 +448,7 @@ var (
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
}
)
@ -452,8 +466,9 @@ type Config struct {
Filename string // Database filename, e.g. "/var/lib/ntfy/user.db"
StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers
DefaultAccess Permission // Default permission if no ACL matches
ProvisionedUsers []*User // Predefined users to create on startup
ProvisionedAccess map[string][]*Grant // Predefined access grants to create on startup
ProvisionEnabled bool // Enable auto-provisioning of users and access grants
ProvisionUsers []*User // Predefined users to create on startup
ProvisionAccess map[string][]*Grant // Predefined access grants to create on startup
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
BcryptCost int // Cost of generated passwords; lowering makes testing faster
}
@ -469,6 +484,11 @@ func NewManager(config *Config) (*Manager, error) {
if config.QueueWriterInterval.Seconds() <= 0 {
config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval
}
// Check the parent directory of the database file (makes for friendly error messages)
parentDir := filepath.Dir(config.Filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
}
// Open DB and run setup queries
db, err := sql.Open("sqlite3", config.Filename)
if err != nil {
@ -486,7 +506,7 @@ func NewManager(config *Config) (*Manager, error) {
statsQueue: make(map[string]*Stats),
tokenQueue: make(map[string]*TokenUpdate),
}
if err := manager.provisionUsers(); err != nil {
if err := manager.maybeProvisionUsersAndAccess(); err != nil {
return nil, err
}
go manager.asyncQueueWriter(config.QueueWriterInterval)
@ -586,7 +606,7 @@ func (a *Manager) Tokens(userID string) ([]*Token, error) {
tokens := make([]*Token, 0)
for {
token, err := a.readToken(rows)
if err == ErrTokenNotFound {
if errors.Is(err, ErrTokenNotFound) {
break
} else if err != nil {
return nil, err
@ -884,6 +904,13 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
// AddUser adds a user with the given username, password and role
func (a *Manager) AddUser(username, password string, role Role, hashed bool) error {
return execTx(a.db, func(tx *sql.Tx) error {
return a.addUserTx(tx, username, password, role, hashed, false)
})
}
// AddUser adds a user with the given username, password and role
func (a *Manager) addUserTx(tx *sql.Tx, username, password string, role Role, hashed, provisioned bool) error {
if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument
}
@ -899,8 +926,8 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err
}
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique {
if _, err = tx.Exec(insertUserQuery, userID, username, hash, role, syncTopic, provisioned, now); err != nil {
if errors.Is(err, sqlite3.ErrConstraintUnique) {
return ErrUserExists
}
return err
@ -911,11 +938,17 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err
// RemoveUser deletes the user with the given username. The function returns nil on success, even
// if the user did not exist in the first place.
func (a *Manager) RemoveUser(username string) error {
return execTx(a.db, func(tx *sql.Tx) error {
return a.removeUserTx(tx, username)
})
}
func (a *Manager) removeUserTx(tx *sql.Tx, username string) error {
if !AllowedUsername(username) {
return ErrInvalidArgument
}
// Rows in user_access, user_token, etc. are deleted via foreign keys
if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
if _, err := tx.Exec(deleteUserQuery, username); err != nil {
return err
}
return nil
@ -1029,24 +1062,26 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var id, username, hash, role, prefs, syncTopic string
var provisioned bool
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
var messages, emails, calls int64
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
user := &User{
ID: id,
Name: username,
Hash: hash,
Role: Role(role),
Prefs: &Prefs{},
SyncTopic: syncTopic,
ID: id,
Name: username,
Hash: hash,
Role: Role(role),
Prefs: &Prefs{},
SyncTopic: syncTopic,
Provisioned: provisioned,
Stats: &Stats{
Messages: messages,
Emails: emails,
@ -1097,8 +1132,8 @@ func (a *Manager) AllGrants() (map[string][]Grant, error) {
grants := make(map[string][]Grant, 0)
for rows.Next() {
var userID, topic string
var read, write bool
if err := rows.Scan(&userID, &topic, &read, &write); err != nil {
var read, write, provisioned bool
if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -1108,7 +1143,8 @@ func (a *Manager) AllGrants() (map[string][]Grant, error) {
}
grants[userID] = append(grants[userID], Grant{
TopicPattern: fromSQLWildcard(topic),
Allow: NewPermission(read, write),
Permission: NewPermission(read, write),
Provisioned: provisioned,
})
}
return grants, nil
@ -1124,15 +1160,16 @@ func (a *Manager) Grants(username string) ([]Grant, error) {
grants := make([]Grant, 0)
for rows.Next() {
var topic string
var read, write bool
if err := rows.Scan(&topic, &read, &write); err != nil {
var read, write, provisioned bool
if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
grants = append(grants, Grant{
TopicPattern: fromSQLWildcard(topic),
Allow: NewPermission(read, write),
Permission: NewPermission(read, write),
Provisioned: provisioned,
})
}
return grants, nil
@ -1218,9 +1255,14 @@ func (a *Manager) ReservationOwner(topic string) (string, error) {
// ChangePassword changes a user's password
func (a *Manager) ChangePassword(username, password string, hashed bool) error {
return execTx(a.db, func(tx *sql.Tx) error {
return a.changePasswordTx(tx, username, password, hashed)
})
}
func (a *Manager) changePasswordTx(tx *sql.Tx, username, password string, hashed bool) error {
var hash []byte
var err error
if hashed {
hash = []byte(password)
} else {
@ -1229,7 +1271,7 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
return err
}
}
if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
if _, err := tx.Exec(updateUserPassQuery, hash, username); err != nil {
return err
}
return nil
@ -1238,14 +1280,20 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
// ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
// all existing access control entries (Grant) are removed, since they are no longer needed.
func (a *Manager) ChangeRole(username string, role Role) error {
return execTx(a.db, func(tx *sql.Tx) error {
return a.changeRoleTx(tx, username, role)
})
}
func (a *Manager) changeRoleTx(tx *sql.Tx, username string, role Role) error {
if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument
}
if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
if _, err := tx.Exec(updateUserRoleQuery, string(role), username); err != nil {
return err
}
if role == RoleAdmin {
if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
if _, err := tx.Exec(deleteUserAccessQuery, username, username); err != nil {
return err
}
}
@ -1325,13 +1373,19 @@ func (a *Manager) AllowReservation(username string, topic string) error {
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
// owner may either be a user (username), or the system (empty).
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
return execTx(a.db, func(tx *sql.Tx) error {
return a.allowAccessTx(tx, username, topicPattern, permission, false)
})
}
func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string, permission Permission, provisioned bool) error {
if !AllowedUsername(username) && username != Everyone {
return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) {
return ErrInvalidArgument
}
owner := ""
if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
if _, err := tx.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner, provisioned); err != nil {
return err
}
return nil
@ -1524,20 +1578,65 @@ func (a *Manager) Close() error {
return a.db.Close()
}
func (a *Manager) provisionUsers() error {
for _, user := range a.config.ProvisionedUsers {
if err := a.AddUser(user.Name, user.Hash, user.Role, true); err != nil && !errors.Is(err, ErrUserExists) {
return err
}
func (a *Manager) maybeProvisionUsersAndAccess() error {
if !a.config.ProvisionEnabled {
return nil
}
for username, grants := range a.config.ProvisionedAccess {
for _, grant := range grants {
if err := a.AllowAccess(username, grant.TopicPattern, grant.Allow); err != nil {
return err
users, err := a.Users()
if err != nil {
return err
}
provisionUsernames := util.Map(a.config.ProvisionUsers, func(u *User) string {
return u.Name
})
return execTx(a.db, func(tx *sql.Tx) error {
// Remove users that are provisioned, but not in the config anymore
for _, user := range users {
if user.Name == Everyone {
continue
} else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) {
log.Tag(tag).Info("Removing previously provisioned user %s", user.Name)
if err := a.removeUserTx(tx, user.Name); err != nil {
return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err)
}
}
}
}
return nil
// Add or update provisioned users
for _, user := range a.config.ProvisionUsers {
if user.Name == Everyone {
continue
}
existingUser, exists := util.Find(users, func(u *User) bool {
return u.Name == user.Name
})
if !exists {
log.Tag(tag).Info("Adding provisioned user %s", user.Name)
if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) {
return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err)
}
} else if existingUser.Hash != user.Hash || existingUser.Role != user.Role {
log.Tag(tag).Info("Updating provisioned user %s", user.Name)
if err := a.changePasswordTx(tx, user.Name, user.Hash, true); err != nil {
return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err)
}
if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil {
return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err)
}
}
}
// Remove and (re-)add provisioned grants
if _, err := tx.Exec(deleteUserAccessProvisionedQuery); err != nil {
return err
}
for username, grants := range a.config.ProvisionAccess {
for _, grant := range grants {
if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil {
return err
}
}
}
return nil
})
}
// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
@ -1711,6 +1810,22 @@ func migrateFrom4(db *sql.DB) error {
return tx.Commit()
}
func migrateFrom5(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 6); err != nil {
return err
}
return tx.Commit()
}
func nullString(s string) sql.NullString {
if s == "" {
return sql.NullString{}
@ -1724,3 +1839,18 @@ func nullInt64(v int64) sql.NullInt64 {
}
return sql.NullInt64{Int64: v, Valid: true}
}
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
if err := f(tx); err != nil {
if e := tx.Rollback(); e != nil {
return err
}
return err
}
return tx.Commit()
}