diff --git a/server/server.go b/server/server.go index 8bd19727..05376e6c 100644 --- a/server/server.go +++ b/server/server.go @@ -40,7 +40,6 @@ import ( message cache duration Keep 10000 messages or keep X days? Attachment expiration based on plan - database migration reserve topics purge accounts that were not logged into in X reset daily limits for users diff --git a/user/manager.go b/user/manager.go index a5f6a370..c4d4de8d 100644 --- a/user/manager.go +++ b/user/manager.go @@ -24,8 +24,7 @@ const ( // Manager-related queries const ( - createAuthTablesQueries = ` - BEGIN; + createTablesQueriesNoTx = ` CREATE TABLE IF NOT EXISTS plan ( id INT NOT NULL, code TEXT NOT NULL, @@ -67,8 +66,8 @@ const ( version INT NOT NULL ); INSERT INTO user (id, user, pass, role) VALUES (1, '*', '', 'anonymous') ON CONFLICT (id) DO NOTHING; - COMMIT; ` + createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;` selectUserByNameQuery = ` SELECT u.user, u.pass, u.role, u.messages, u.emails, u.settings, p.code, p.messages_limit, p.emails_limit, p.attachment_file_size_limit, p.attachment_total_size_limit FROM user u @@ -130,9 +129,27 @@ const ( // Schema management queries const ( - currentSchemaVersion = 1 + currentSchemaVersion = 2 insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` + updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` + + // 1 -> 2 (complex migration!) + migrate1To2RenameUserTableQueryNoTx = ` + ALTER TABLE user RENAME TO user_old; + ` + migrate1To2InsertFromOldTablesAndDropNoTx = ` + INSERT INTO user (user, pass, role) + SELECT user, pass, role FROM user_old; + + INSERT INTO user_access (user_id, topic, read, write) + SELECT u.id, a.topic, a.read, a.write + FROM user u + JOIN access a ON u.user = a.user; + + DROP TABLE access; + DROP TABLE user_old; + ` ) // Manager is an implementation of Manager. It stores users and access control list @@ -159,7 +176,7 @@ func newManager(filename string, defaultRead, defaultWrite bool, tokenExpiryDura if err != nil { return nil, err } - if err := setupAuthDB(db); err != nil { + if err := setupDB(db); err != nil { return nil, err } manager := &Manager{ @@ -364,16 +381,21 @@ func (a *Manager) RemoveUser(username string) error { if !AllowedUsername(username) { return ErrInvalidArgument } - if _, err := a.db.Exec(deleteUserAccessQuery, username); err != nil { + tx, err := a.db.Begin() + if err != nil { return err } - if _, err := a.db.Exec(deleteUserTokensQuery, username); err != nil { + defer tx.Rollback() + if _, err := tx.Exec(deleteUserAccessQuery, username); err != nil { return err } - if _, err := a.db.Exec(deleteUserQuery, username); err != nil { + if _, err := tx.Exec(deleteUserTokensQuery, username); err != nil { return err } - return nil + if _, err := tx.Exec(deleteUserQuery, username); err != nil { + return err + } + return tx.Commit() } // Users returns a list of users. It always also returns the Everyone user ("*"). @@ -567,11 +589,11 @@ func fromSQLWildcard(s string) string { return strings.ReplaceAll(s, "%", "*") } -func setupAuthDB(db *sql.DB) error { +func setupDB(db *sql.DB) error { // If 'schemaVersion' table does not exist, this must be a new database rowsSV, err := db.Query(selectSchemaVersionQuery) if err != nil { - return setupNewAuthDB(db) + return setupNewDB(db) } defer rowsSV.Close() @@ -588,12 +610,14 @@ func setupAuthDB(db *sql.DB) error { // Do migrations if schemaVersion == currentSchemaVersion { return nil + } else if schemaVersion == 1 { + return migrateFrom1(db) } return fmt.Errorf("unexpected schema version found: %d", schemaVersion) } -func setupNewAuthDB(db *sql.DB) error { - if _, err := db.Exec(createAuthTablesQueries); err != nil { +func setupNewDB(db *sql.DB) error { + if _, err := db.Exec(createTablesQueries); err != nil { return err } if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil { @@ -601,3 +625,28 @@ func setupNewAuthDB(db *sql.DB) error { } return nil } + +func migrateFrom1(db *sql.DB) error { + log.Info("Migrating user database schema: from 1 to 2") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil { + return err + } + if _, err := tx.Exec(createTablesQueriesNoTx); err != nil { + return err + } + if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 2); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return err + } + return nil // Update this when a new version is added +} diff --git a/user/manager_test.go b/user/manager_test.go index c3a7e5e1..c9669ad1 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -1,6 +1,7 @@ package user import ( + "database/sql" "github.com/stretchr/testify/require" "path/filepath" "strings" @@ -350,8 +351,95 @@ func TestManager_EnqueueStats(t *testing.T) { require.Equal(t, int64(2), u.Stats.Emails) } +func TestSqliteCache_Migration_From1(t *testing.T) { + filename := filepath.Join(t.TempDir(), "user.db") + db, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + + // Create "version 1" schema + _, err = db.Exec(` + BEGIN; + CREATE TABLE IF NOT EXISTS user ( + user TEXT NOT NULL PRIMARY KEY, + pass TEXT NOT NULL, + role TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS access ( + user TEXT NOT NULL, + topic TEXT NOT NULL, + read INT NOT NULL, + write INT NOT NULL, + PRIMARY KEY (topic, user) + ); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO schemaVersion (id, version) VALUES (1, 1); + COMMIT; + `) + require.Nil(t, err) + + // Insert a bunch of users and ACL entries + _, err = db.Exec(` + BEGIN; + INSERT INTO user (user, pass, role) VALUES ('ben', '$2a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy', 'user'); + INSERT INTO user (user, pass, role) VALUES ('phil', '$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C', 'admin'); + INSERT INTO access (user, topic, read, write) VALUES ('ben', 'stats', 1, 1); + INSERT INTO access (user, topic, read, write) VALUES ('ben', 'secret', 1, 0); + INSERT INTO access (user, topic, read, write) VALUES ('*', 'stats', 1, 0); + COMMIT; + `) + require.Nil(t, err) + + // Create manager to trigger migration + a := newTestManagerFromFile(t, filename, false, false, userTokenExpiryDuration, userStatsQueueWriterInterval) + checkSchemaVersion(t, a.db) + + users, err := a.Users() + require.Nil(t, err) + require.Equal(t, 3, len(users)) + phil, ben, everyone := users[0], users[1], users[2] + + require.Equal(t, "phil", phil.Name) + require.Equal(t, RoleAdmin, phil.Role) + require.Equal(t, 0, len(phil.Grants)) + + require.Equal(t, "ben", ben.Name) + require.Equal(t, RoleUser, ben.Role) + require.Equal(t, 2, len(ben.Grants)) + require.Equal(t, "stats", ben.Grants[0].TopicPattern) + require.Equal(t, true, ben.Grants[0].AllowRead) + require.Equal(t, true, ben.Grants[0].AllowWrite) + require.Equal(t, "secret", ben.Grants[1].TopicPattern) + require.Equal(t, true, ben.Grants[1].AllowRead) + require.Equal(t, false, ben.Grants[1].AllowWrite) + + require.Equal(t, Everyone, everyone.Name) + require.Equal(t, RoleAnonymous, everyone.Role) + require.Equal(t, 1, len(everyone.Grants)) + require.Equal(t, "stats", everyone.Grants[0].TopicPattern) + require.Equal(t, true, everyone.Grants[0].AllowRead) + require.Equal(t, false, everyone.Grants[0].AllowWrite) +} + +func checkSchemaVersion(t *testing.T, db *sql.DB) { + rows, err := db.Query(`SELECT version FROM schemaVersion`) + require.Nil(t, err) + require.True(t, rows.Next()) + + var schemaVersion int + require.Nil(t, rows.Scan(&schemaVersion)) + require.Equal(t, currentSchemaVersion, schemaVersion) + require.Nil(t, rows.Close()) +} + func newTestManager(t *testing.T, defaultRead, defaultWrite bool) *Manager { - a, err := NewManager(filepath.Join(t.TempDir(), "db"), defaultRead, defaultWrite) + return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), defaultRead, defaultWrite, userTokenExpiryDuration, userStatsQueueWriterInterval) +} + +func newTestManagerFromFile(t *testing.T, filename string, defaultRead, defaultWrite bool, tokenExpiryDuration, statsWriterInterval time.Duration) *Manager { + a, err := newManager(filename, defaultRead, defaultWrite, tokenExpiryDuration, statsWriterInterval) require.Nil(t, err) return a }