diff --git a/user/manager.go b/user/manager.go index 7f3b8b1d..324b7684 100644 --- a/user/manager.go +++ b/user/manager.go @@ -160,7 +160,7 @@ const ( SELECT read, write FROM user_access a JOIN user u ON u.id = a.user_id - WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic + WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\' ORDER BY u.user DESC ` @@ -235,7 +235,7 @@ const ( selectOtherAccessCountQuery = ` SELECT COUNT(*) FROM user_access - WHERE (topic = ? OR ? LIKE topic) + WHERE (topic = ? OR ? LIKE topic ESCAPE '\') AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?)) ` deleteAllAccessQuery = `DELETE FROM user_access` @@ -312,7 +312,7 @@ const ( // Schema management queries const ( - currentSchemaVersion = 4 + currentSchemaVersion = 5 insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` @@ -422,6 +422,11 @@ const ( FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE ); ` + + // 4 -> 5 + migrate4To5UpdateQueries = ` + UPDATE user_access SET topic = REPLACE(topic, '_', '\_'); + ` ) var ( @@ -429,6 +434,7 @@ var ( 1: migrateFrom1, 2: migrateFrom2, 3: migrateFrom3, + 4: migrateFrom4, } ) @@ -1123,7 +1129,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) { return nil, err } reservations = append(reservations, Reservation{ - Topic: topic, + Topic: unescapeUnderscore(topic), Owner: NewPermission(ownerRead, ownerWrite), Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null }) @@ -1133,7 +1139,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) { // HasReservation returns true if the given topic access is owned by the user func (a *Manager) HasReservation(username, topic string) (bool, error) { - rows, err := a.db.Query(selectUserHasReservationQuery, username, topic) + rows, err := a.db.Query(selectUserHasReservationQuery, username, escapeUnderscore(topic)) if err != nil { return false, err } @@ -1168,7 +1174,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) { // ReservationOwner returns user ID of the user that owns this topic, or an // empty string if it's not owned by anyone func (a *Manager) ReservationOwner(topic string) (string, error) { - rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic) + rows, err := a.db.Query(selectUserReservationsOwnerQuery, escapeUnderscore(topic)) if err != nil { return "", err } @@ -1263,7 +1269,7 @@ func (a *Manager) AllowReservation(username string, topic string) error { if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) { return ErrInvalidArgument } - rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username) + rows, err := a.db.Query(selectOtherAccessCountQuery, escapeUnderscore(topic), escapeUnderscore(topic), username) if err != nil { return err } @@ -1328,10 +1334,10 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss return err } defer tx.Rollback() - if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil { + if _, err := tx.Exec(upsertUserAccessQuery, username, escapeUnderscore(topic), true, true, username, username); err != nil { return err } - if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil { + if _, err := tx.Exec(upsertUserAccessQuery, Everyone, escapeUnderscore(topic), everyone.IsRead(), everyone.IsWrite(), username, username); err != nil { return err } return tx.Commit() @@ -1354,10 +1360,10 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error { } defer tx.Rollback() for _, topic := range topics { - if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil { + if _, err := tx.Exec(deleteTopicAccessQuery, username, username, escapeUnderscore(topic)); err != nil { return err } - if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil { + if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, escapeUnderscore(topic)); err != nil { return err } } @@ -1484,12 +1490,24 @@ func (a *Manager) Close() error { return a.db.Close() } +// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards, +// and escapes '_', assuming '\' as escape character. func toSQLWildcard(s string) string { - return strings.ReplaceAll(s, "*", "%") + return escapeUnderscore(strings.ReplaceAll(s, "*", "%")) } +// fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*', +// and removes the '\_' escape character. func fromSQLWildcard(s string) string { - return strings.ReplaceAll(s, "%", "*") + return strings.ReplaceAll(unescapeUnderscore(s), "%", "*") +} + +func escapeUnderscore(s string) string { + return strings.ReplaceAll(s, "_", "\\_") +} + +func unescapeUnderscore(s string) string { + return strings.ReplaceAll(s, "\\_", "_") } func runStartupQueries(db *sql.DB, startupQueries string) error { @@ -1627,6 +1645,22 @@ func migrateFrom3(db *sql.DB) error { return tx.Commit() } +func migrateFrom4(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 4 to 5") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 5); err != nil { + return err + } + return tx.Commit() +} + func nullString(s string) sql.NullString { if s == "" { return sql.NullString{} diff --git a/user/manager_test.go b/user/manager_test.go index 3c30a716..468dc36a 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -330,7 +330,7 @@ func TestManager_Reservations(t *testing.T) { a := newTestManager(t, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleUser)) require.Nil(t, a.AddUser("ben", "ben", RoleUser)) - require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll)) + require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll)) require.Nil(t, a.AddReservation("ben", "readme", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead)) @@ -343,7 +343,7 @@ func TestManager_Reservations(t *testing.T) { Everyone: PermissionRead, }, reservations[0]) require.Equal(t, Reservation{ - Topic: "ztopic", + Topic: "ztopic_", Owner: PermissionReadWrite, Everyone: PermissionDenyAll, }, reservations[1]) @@ -352,6 +352,14 @@ func TestManager_Reservations(t *testing.T) { require.Nil(t, err) require.True(t, b) + b, err = a.HasReservation("ben", "ztopic_") + require.Nil(t, err) + require.True(t, b) + + b, err = a.HasReservation("ben", "ztopicX") // _ != X (used to be a SQL wildcard issue) + require.Nil(t, err) + require.False(t, b) + b, err = a.HasReservation("notben", "readme") require.Nil(t, err) require.False(t, b) @@ -371,11 +379,17 @@ func TestManager_Reservations(t *testing.T) { err = a.AllowReservation("phil", "readme") require.Equal(t, errTopicOwnedByOthers, err) + err = a.AllowReservation("phil", "ztopic_") + require.Equal(t, errTopicOwnedByOthers, err) + + err = a.AllowReservation("phil", "ztopicX") + require.Nil(t, err) + err = a.AllowReservation("phil", "not-reserved") require.Nil(t, err) // Now remove them again - require.Nil(t, a.RemoveReservations("ben", "ztopic", "readme")) + require.Nil(t, a.RemoveReservations("ben", "ztopic_", "readme")) count, err = a.ReservationsCount("ben") require.Nil(t, err) @@ -978,7 +992,44 @@ func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) { require.Nil(t, a.AddPhoneNumber(ben.ID, "+1234567890")) } -func TestSqliteCache_Migration_From1(t *testing.T) { +func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) { + f := filepath.Join(t.TempDir(), "user.db") + a := newTestManagerFromFile(t, f, "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval) + require.Nil(t, a.AllowAccess(Everyone, "*_", PermissionRead)) + require.Nil(t, a.AllowAccess(Everyone, "__*_", PermissionRead)) + require.Nil(t, a.Authorize(nil, "allowed_", PermissionRead)) + require.Nil(t, a.Authorize(nil, "__allowed_", PermissionRead)) + require.Nil(t, a.Authorize(nil, "_allowed_", PermissionRead)) // The "%" in "%\_" matches the first "_" + require.Equal(t, ErrUnauthorized, a.Authorize(nil, "notallowed", PermissionRead)) + require.Equal(t, ErrUnauthorized, a.Authorize(nil, "_notallowed", PermissionRead)) + require.Equal(t, ErrUnauthorized, a.Authorize(nil, "__notallowed", PermissionRead)) +} + +func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) { + f := filepath.Join(t.TempDir(), "user.db") + a := newTestManagerFromFile(t, f, "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval) + require.Nil(t, a.AllowAccess(Everyone, "mytopic_", PermissionReadWrite)) + require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead)) + require.Nil(t, a.Authorize(nil, "mytopic_", PermissionWrite)) + require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionRead)) + require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionWrite)) +} + +func TestToFromSQLWildcard(t *testing.T) { + require.Equal(t, "up%", toSQLWildcard("up*")) + require.Equal(t, "up\\_%", toSQLWildcard("up_*")) + require.Equal(t, "foo", toSQLWildcard("foo")) + + require.Equal(t, "up*", fromSQLWildcard("up%")) + require.Equal(t, "up_*", fromSQLWildcard("up\\_%")) + require.Equal(t, "foo", fromSQLWildcard("foo")) + + require.Equal(t, "up*", fromSQLWildcard(toSQLWildcard("up*"))) + require.Equal(t, "up_*", fromSQLWildcard(toSQLWildcard("up_*"))) + require.Equal(t, "foo", fromSQLWildcard(toSQLWildcard("foo"))) +} + +func TestMigrationFrom1(t *testing.T) { filename := filepath.Join(t.TempDir(), "user.db") db, err := sql.Open("sqlite3", filename) require.Nil(t, err) @@ -1063,6 +1114,152 @@ func TestSqliteCache_Migration_From1(t *testing.T) { require.Equal(t, PermissionRead, everyoneGrants[0].Allow) } +func TestMigrationFrom4(t *testing.T) { + filename := filepath.Join(t.TempDir(), "user.db") + db, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + + // Create "version 4" schema + _, err = db.Exec(` + BEGIN; + CREATE TABLE IF NOT EXISTS tier ( + id TEXT PRIMARY KEY, + code TEXT NOT NULL, + name TEXT NOT NULL, + messages_limit INT NOT NULL, + messages_expiry_duration INT NOT NULL, + emails_limit INT NOT NULL, + calls_limit INT NOT NULL, + reservations_limit INT NOT NULL, + attachment_file_size_limit INT NOT NULL, + attachment_total_size_limit INT NOT NULL, + attachment_expiry_duration INT NOT NULL, + attachment_bandwidth_limit INT NOT NULL, + stripe_monthly_price_id TEXT, + stripe_yearly_price_id TEXT + ); + CREATE UNIQUE INDEX idx_tier_code ON tier (code); + CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); + CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); + CREATE TABLE IF NOT EXISTS user ( + id TEXT PRIMARY KEY, + tier_id TEXT, + user TEXT NOT NULL, + pass TEXT NOT NULL, + role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, + prefs JSON NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + stats_messages INT NOT NULL DEFAULT (0), + stats_emails INT NOT NULL DEFAULT (0), + stats_calls INT NOT NULL DEFAULT (0), + stripe_customer_id TEXT, + stripe_subscription_id TEXT, + stripe_subscription_status TEXT, + stripe_subscription_interval TEXT, + stripe_subscription_paid_until INT, + stripe_subscription_cancel_at INT, + created INT NOT NULL, + deleted INT, + FOREIGN KEY (tier_id) REFERENCES tier (id) + ); + CREATE UNIQUE INDEX idx_user ON user (user); + CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); + CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); + CREATE TABLE IF NOT EXISTS user_access ( + user_id TEXT NOT NULL, + topic TEXT NOT NULL, + read INT NOT NULL, + write INT NOT NULL, + owner_user_id INT, + PRIMARY KEY (user_id, topic), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, + FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL, + token TEXT NOT NULL, + label TEXT NOT NULL, + last_access INT NOT NULL, + last_origin TEXT NOT NULL, + expires INT NOT NULL, + PRIMARY KEY (user_id, token), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS user_phone ( + user_id TEXT NOT NULL, + phone_number TEXT NOT NULL, + PRIMARY KEY (user_id, phone_number), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO user (id, user, pass, role, sync_topic, created) + VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH()) + ON CONFLICT (id) DO NOTHING; + INSERT INTO schemaVersion (id, version) VALUES (1, 4); + COMMIT; + `) + require.Nil(t, err) + + // Insert a few ACL entries + _, err = db.Exec(` + BEGIN; + INSERT INTO user_access (user_id, topic, read, write) values ('u_everyone', 'mytopic_', 1, 1); + INSERT INTO user_access (user_id, topic, read, write) values ('u_everyone', 'up%', 1, 1); + INSERT INTO user_access (user_id, topic, read, write) values ('u_everyone', 'down_%', 1, 1); + COMMIT; + `) + require.Nil(t, err) + + // Create manager to trigger migration + a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval) + checkSchemaVersion(t, a.db) + + // Add another + require.Nil(t, a.AllowAccess(Everyone, "left_*", PermissionReadWrite)) + + // Check "external view" of grants + everyoneGrants, err := a.Grants(Everyone) + require.Nil(t, err) + + require.Equal(t, 4, len(everyoneGrants)) + require.Equal(t, "down_*", everyoneGrants[0].TopicPattern) + require.Equal(t, "left_*", everyoneGrants[1].TopicPattern) + require.Equal(t, "mytopic_", everyoneGrants[2].TopicPattern) + require.Equal(t, "up*", everyoneGrants[3].TopicPattern) + + // Check they are stored correctly in the database + rows, err := db.Query(`SELECT topic FROM user_access WHERE user_id = 'u_everyone' ORDER BY topic`) + require.Nil(t, err) + topicPatterns := make([]string, 0) + for rows.Next() { + var topicPattern string + require.Nil(t, rows.Scan(&topicPattern)) + topicPatterns = append(topicPatterns, topicPattern) + } + require.Nil(t, rows.Close()) + require.Equal(t, 4, len(topicPatterns)) + require.Equal(t, "down\\_%", topicPatterns[0]) + require.Equal(t, "left\\_%", topicPatterns[1]) + require.Equal(t, "mytopic\\_", topicPatterns[2]) + require.Equal(t, "up%", topicPatterns[3]) + + // Check that ACL works as excepted + require.Nil(t, a.Authorize(nil, "down_123", PermissionRead)) + require.Equal(t, ErrUnauthorized, a.Authorize(nil, "downX123", PermissionRead)) + + require.Nil(t, a.Authorize(nil, "left_abc", PermissionRead)) + require.Equal(t, ErrUnauthorized, a.Authorize(nil, "leftX123", PermissionRead)) + + require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead)) + require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionRead)) + + require.Nil(t, a.Authorize(nil, "up123", PermissionRead)) + require.Nil(t, a.Authorize(nil, "up", PermissionRead)) // % matches 0 or more characters +} + func checkSchemaVersion(t *testing.T, db *sql.DB) { rows, err := db.Query(`SELECT version FROM schemaVersion`) require.Nil(t, err)