1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-06-23 21:08:05 +02:00

Startup queries, foreign keys

This commit is contained in:
binwiederhier 2023-01-05 15:20:44 -05:00
parent 3280c2c440
commit 60f1882bec
14 changed files with 148 additions and 69 deletions

View file

@ -59,7 +59,8 @@ const (
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id INT NOT NULL,
@ -75,6 +76,10 @@ const (
INSERT INTO user (id, user, pass, role) VALUES (1, '*', '', 'anonymous') ON CONFLICT (id) DO NOTHING;
`
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.messages, u.emails, u.settings, p.code, p.messages_limit, p.emails_limit, p.topics_limit, p.attachment_file_size_limit, p.attachment_total_size_limit
FROM user u
@ -95,10 +100,7 @@ const (
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
ORDER BY u.user DESC
`
)
// Manager-related queries
const (
insertUserQuery = `INSERT INTO user (user, pass, role) VALUES (?, ?, ?)`
selectUsernamesQuery = `
SELECT user
@ -150,7 +152,6 @@ const (
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
deleteUserTokensQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?)`
)
// Schema management queries
@ -191,12 +192,12 @@ type Manager struct {
var _ Auther = (*Manager)(nil)
// NewManager creates a new Manager instance
func NewManager(filename string, defaultAccess Permission) (*Manager, error) {
return newManager(filename, defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval)
func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
return newManager(filename, startupQueries, defaultAccess, userTokenExpiryDuration, userStatsQueueWriterInterval)
}
// NewManager creates a new Manager instance
func newManager(filename string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) (*Manager, error) {
func newManager(filename, startupQueries string, defaultAccess Permission, tokenExpiryDuration, statsWriterInterval time.Duration) (*Manager, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
@ -204,6 +205,9 @@ func newManager(filename string, defaultAccess Permission, tokenExpiryDuration,
if err := setupDB(db); err != nil {
return nil, err
}
if err := runStartupQueries(db, startupQueries); err != nil {
return nil, err
}
manager := &Manager{
db: db,
defaultAccess: defaultAccess,
@ -223,11 +227,12 @@ func (a *Manager) Authenticate(username, password string) (*User, error) {
}
user, err := a.User(username)
if err != nil {
bcrypt.CompareHashAndPassword([]byte(intentionalSlowDownHash),
[]byte("intentional slow-down to avoid timing attacks"))
log.Trace("authentication of user %s failed (1): %s", username, err.Error())
bcrypt.CompareHashAndPassword([]byte(intentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
log.Trace("authentication of user %s failed (2): %s", username, err.Error())
return nil, ErrUnauthenticated
}
return user, nil
@ -407,21 +412,11 @@ func (a *Manager) RemoveUser(username string) error {
if !AllowedUsername(username) {
return ErrInvalidArgument
}
tx, err := a.db.Begin()
if err != nil {
// Rows in user_access, user_token, etc. are deleted via foreign keys
if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(deleteUserAccessQuery, username); err != nil {
return err
}
if _, err := tx.Exec(deleteUserTokensQuery, username); err != nil {
return err
}
if _, err := tx.Exec(deleteUserQuery, username); err != nil {
return err
}
return tx.Commit()
return nil
}
// Users returns a list of users. It always also returns the Everyone user ("*").
@ -666,6 +661,16 @@ func fromSQLWildcard(s string) string {
return strings.ReplaceAll(s, "%", "*")
}
func runStartupQueries(db *sql.DB, startupQueries string) error {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
if _, err := db.Exec(builtinStartupQueries); err != nil {
return err
}
return nil
}
func setupDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rowsSV, err := db.Query(selectSchemaVersionQuery)