From d499d20a9cdc15b2f1dc44d4e2c35e75acb275a2 Mon Sep 17 00:00:00 2001
From: Philipp Heckel <pheckel@datto.com>
Date: Sat, 3 Dec 2022 15:20:59 -0500
Subject: [PATCH] Token stuff

---
 auth/auth.go                   |  15 ++--
 auth/auth_sqlite.go            |  63 +++++++++++++---
 server/server.go               | 128 +++++++++++++++++++++------------
 server/server_firebase_test.go |   2 +-
 web/src/app/Api.js             |  20 +++++-
 web/src/app/utils.js           |   9 ++-
 web/src/components/App.js      |  19 +++++
 web/src/components/Login.js    |   2 +-
 8 files changed, 194 insertions(+), 64 deletions(-)

diff --git a/auth/auth.go b/auth/auth.go
index 35b910c5..93a0ecf4 100644
--- a/auth/auth.go
+++ b/auth/auth.go
@@ -6,13 +6,17 @@ import (
 	"regexp"
 )
 
-// Auther is a generic interface to implement password-based authentication and authorization
+// Auther is a generic interface to implement password and token based authentication and authorization
 type Auther interface {
 	// Authenticate checks username and password and returns a user if correct. The method
 	// returns in constant-ish time, regardless of whether the user exists or the password is
 	// correct or incorrect.
 	Authenticate(username, password string) (*User, error)
 
+	AuthenticateToken(token string) (*User, error)
+
+	GenerateToken(user *User) (string, error)
+
 	// Authorize returns nil if the given user has access to the given topic using the desired
 	// permission. The user param may be nil to signal an anonymous user.
 	Authorize(user *User, topic string, perm Permission) error
@@ -56,10 +60,11 @@ type Manager interface {
 
 // User is a struct that represents a user
 type User struct {
-	Name   string
-	Hash   string // password hash (bcrypt)
-	Role   Role
-	Grants []Grant
+	Name     string
+	Hash     string // password hash (bcrypt)
+	Role     Role
+	Grants   []Grant
+	Language string
 }
 
 // Grant is a struct that represents an access control entry to a topic
diff --git a/auth/auth_sqlite.go b/auth/auth_sqlite.go
index a91c45ef..4ed89e3b 100644
--- a/auth/auth_sqlite.go
+++ b/auth/auth_sqlite.go
@@ -6,10 +6,12 @@ import (
 	"fmt"
 	_ "github.com/mattn/go-sqlite3" // SQLite driver
 	"golang.org/x/crypto/bcrypt"
+	"heckel.io/ntfy/util"
 	"strings"
 )
 
 const (
+	tokenLength             = 32
 	bcryptCost              = 10
 	intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
 )
@@ -67,7 +69,17 @@ const (
 		INSERT INTO user (id, user, pass, role) VALUES (1, '*', '', 'anonymous') ON CONFLICT (id) DO NOTHING;
 		COMMIT;
 	`
-	selectUserQuery       = `SELECT pass, role FROM user WHERE user = ?`
+	selectUserByNameQuery = `
+		SELECT user, pass, role, language 
+		FROM user 
+		WHERE user = ?
+	`
+	selectUserByTokenQuery = `
+		SELECT user, pass, role, language 
+		FROM user
+		JOIN user_token on user.id = user_token.user_id
+		WHERE token = ?
+	`
 	selectTopicPermsQuery = `
 		SELECT read, write 
 		FROM user_access
@@ -90,6 +102,8 @@ const (
 	deleteAllAccessQuery   = `DELETE FROM user_access`
 	deleteUserAccessQuery  = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
 	deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
+
+	insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
 )
 
 // Schema management queries
@@ -126,7 +140,7 @@ func NewSQLiteAuth(filename string, defaultRead, defaultWrite bool) (*SQLiteAuth
 	}, nil
 }
 
-// Authenticate checks username and password and returns a user if correct. The method
+// AuthenticateUser checks username and password and returns a user if correct. The method
 // returns in constant-ish time, regardless of whether the user exists or the password is
 // correct or incorrect.
 func (a *SQLiteAuth) Authenticate(username, password string) (*User, error) {
@@ -145,6 +159,23 @@ func (a *SQLiteAuth) Authenticate(username, password string) (*User, error) {
 	return user, nil
 }
 
+func (a *SQLiteAuth) AuthenticateToken(token string) (*User, error) {
+	user, err := a.userByToken(token)
+	if err != nil {
+		return nil, ErrUnauthenticated
+	}
+	return user, nil
+}
+
+func (a *SQLiteAuth) GenerateToken(user *User) (string, error) {
+	token := util.RandomString(tokenLength)
+	expires := 1 // FIXME
+	if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires); err != nil {
+		return "", err
+	}
+	return token, nil
+}
+
 // Authorize returns nil if the given user has access to the given topic using the desired
 // permission. The user param may be nil to signal an anonymous user.
 func (a *SQLiteAuth) Authorize(user *User, topic string, perm Permission) error {
@@ -255,16 +286,29 @@ func (a *SQLiteAuth) User(username string) (*User, error) {
 	if username == Everyone {
 		return a.everyoneUser()
 	}
-	rows, err := a.db.Query(selectUserQuery, username)
+	rows, err := a.db.Query(selectUserByNameQuery, username)
 	if err != nil {
 		return nil, err
 	}
+	return a.readUser(rows)
+}
+
+func (a *SQLiteAuth) userByToken(token string) (*User, error) {
+	rows, err := a.db.Query(selectUserByTokenQuery, token)
+	if err != nil {
+		return nil, err
+	}
+	return a.readUser(rows)
+}
+
+func (a *SQLiteAuth) readUser(rows *sql.Rows) (*User, error) {
 	defer rows.Close()
-	var hash, role string
+	var username, hash, role string
+	var language sql.NullString
 	if !rows.Next() {
 		return nil, ErrNotFound
 	}
-	if err := rows.Scan(&hash, &role); err != nil {
+	if err := rows.Scan(&username, &hash, &role, &language); err != nil {
 		return nil, err
 	} else if err := rows.Err(); err != nil {
 		return nil, err
@@ -274,10 +318,11 @@ func (a *SQLiteAuth) User(username string) (*User, error) {
 		return nil, err
 	}
 	return &User{
-		Name:   username,
-		Hash:   hash,
-		Role:   Role(role),
-		Grants: grants,
+		Name:     username,
+		Hash:     hash,
+		Role:     Role(role),
+		Grants:   grants,
+		Language: language.String,
 	}, nil
 }
 
diff --git a/server/server.go b/server/server.go
index ec573878..75f49d6f 100644
--- a/server/server.go
+++ b/server/server.go
@@ -7,6 +7,7 @@ import (
 	"embed"
 	"encoding/base64"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -320,23 +321,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
 	} else if r.Method == http.MethodOptions {
 		return s.ensureWebEnabled(s.handleOptions)(w, r, v)
 	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
-		return s.limitRequests(s.transformBodyJSON(s.authWrite(s.handlePublish)))(w, r, v)
+		return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
 	} else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
-		return s.limitRequests(s.transformMatrixJSON(s.authWrite(s.handlePublishMatrix)))(w, r, v)
+		return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
 	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
-		return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
+		return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
 	} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
-		return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
+		return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
 	} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
-		return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
+		return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
 	} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
-		return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v)
+		return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
 	} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
-		return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
+		return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
 	} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
-		return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
+		return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
 	} else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
-		return s.limitRequests(s.authRead(s.handleTopicAuth))(w, r, v)
+		return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
 	} else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
 		return s.ensureWebEnabled(s.handleTopic)(w, r, v)
 	}
@@ -403,8 +404,6 @@ func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visi
 	return nil
 }
 
-var sessions = make(map[string]*auth.User) // token-> user
-
 type tokenAuthResponse struct {
 	Token string `json:"token"`
 }
@@ -414,8 +413,10 @@ func (s *Server) handleUserAuth(w http.ResponseWriter, r *http.Request, v *visit
 	if v.user == nil {
 		return errHTTPUnauthorized
 	}
-	token := util.RandomString(32)
-	sessions[token] = v.user
+	token, err := s.auth.GenerateToken(v.user)
+	if err != nil {
+		return err
+	}
 	w.Header().Set("Content-Type", "text/json")
 	w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
 	response := &tokenAuthResponse{
@@ -432,35 +433,41 @@ type userSubscriptionResponse struct {
 	Topic   string `json:"topic"`
 }
 
+type userNotificationSettingsResponse struct {
+	Sound       string `json:"sound"`
+	MinPriority string `json:"min_priority"`
+	DeleteAfter int    `json:"delete_after"`
+}
+
+type userPlanResponse struct {
+	Id   int    `json:"id"`
+	Name string `json:"name"`
+}
+
 type userAccountResponse struct {
-	Username string `json:"username"`
-	Role     string `json:"role,omitempty"`
-	Language string `json:"language,omitempty"`
-	Plan     struct {
-		Id   int    `json:"id"`
-		Name string `json:"name"`
-	} `json:"plan,omitempty"`
-	Notification struct {
-		Sound       string `json:"sound"`
-		MinPriority string `json:"min_priority"`
-		DeleteAfter int    `json:"delete_after"`
-	} `json:"notification,omitempty"`
-	Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
+	Username      string                            `json:"username"`
+	Role          string                            `json:"role,omitempty"`
+	Language      string                            `json:"language,omitempty"`
+	Plan          *userPlanResponse                 `json:"plan,omitempty"`
+	Notification  *userNotificationSettingsResponse `json:"notification,omitempty"`
+	Subscriptions []*userSubscriptionResponse       `json:"subscriptions,omitempty"`
 }
 
 func (s *Server) handleUserAccount(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	w.Header().Set("Content-Type", "text/json")
 	w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
-	var response *userAccountResponse
+	response := &userAccountResponse{}
 	if v.user != nil {
-		response = &userAccountResponse{
-			Username: v.user.Name,
-			Role:     string(v.user.Role),
-			Language: "en_US",
+		response.Username = v.user.Name
+		response.Role = string(v.user.Role)
+		response.Language = v.user.Language
+		response.Notification = &userNotificationSettingsResponse{
+			Sound: "dadum",
 		}
 	} else {
 		response = &userAccountResponse{
-			Username: "anonymous",
+			Username: auth.Everyone,
+			Role:     string(auth.RoleAnonymous),
 		}
 	}
 	if err := json.NewEncoder(w).Encode(response); err != nil {
@@ -1453,15 +1460,15 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
 	}
 }
 
-func (s *Server) authWrite(next handleFunc) handleFunc {
-	return s.withAuth(next, auth.PermissionWrite)
+func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
+	return s.autorizeTopic(next, auth.PermissionWrite)
 }
 
-func (s *Server) authRead(next handleFunc) handleFunc {
-	return s.withAuth(next, auth.PermissionRead)
+func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
+	return s.autorizeTopic(next, auth.PermissionRead)
 }
 
-func (s *Server) withAuth(next handleFunc, perm auth.Permission) handleFunc {
+func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
 		if s.auth == nil {
 			return next(w, r, v)
@@ -1508,20 +1515,51 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
 	visitorID := fmt.Sprintf("ip:%s", ip.String())
 
 	var user *auth.User // may stay nil if no auth header!
-	username, password, ok := extractUserPass(r)
-	if ok {
-		if user, err = s.auth.Authenticate(username, password); err != nil {
-			log.Debug("authentication failed: %s", err.Error())
-			err = errHTTPUnauthorized // Always return visitor, even when error occurs!
-		} else {
-			visitorID = fmt.Sprintf("user:%s", user.Name)
-		}
+	if user, err = s.authenticate(r); err != nil {
+		log.Debug("authentication failed: %s", err.Error())
+		err = errHTTPUnauthorized // Always return visitor, even when error occurs!
+	}
+	if user != nil {
+		visitorID = fmt.Sprintf("user:%s", user.Name)
 	}
 	v = s.visitorFromID(visitorID, ip, user)
 	v.user = user // Update user -- FIXME this is ugly, do "newVisitorFromUser" instead
 	return v, err // Always return visitor, even when error occurs!
 }
 
+func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) {
+	value := r.Header.Get("Authorization")
+	queryParam := readQueryParam(r, "authorization", "auth")
+	if queryParam != "" {
+		a, err := base64.RawURLEncoding.DecodeString(queryParam)
+		if err != nil {
+			return nil, err
+		}
+		value = string(a)
+	}
+	if value == "" {
+		return nil, nil
+	}
+	if strings.HasPrefix(value, "Bearer") {
+		return s.authenticateBearerAuth(value)
+	}
+	return s.authenticateBasicAuth(r, value)
+}
+
+func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *auth.User, err error) {
+	r.Header.Set("Authorization", value)
+	username, password, ok := r.BasicAuth()
+	if !ok {
+		return nil, errors.New("invalid basic auth")
+	}
+	return s.auth.Authenticate(username, password)
+}
+
+func (s *Server) authenticateBearerAuth(value string) (user *auth.User, err error) {
+	token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
+	return s.auth.AuthenticateToken(token)
+}
+
 func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor {
 	s.mu.Lock()
 	defer s.mu.Unlock()
diff --git a/server/server_firebase_test.go b/server/server_firebase_test.go
index 36fd8b51..ba2ab1d7 100644
--- a/server/server_firebase_test.go
+++ b/server/server_firebase_test.go
@@ -18,7 +18,7 @@ type testAuther struct {
 	Allow bool
 }
 
-func (t testAuther) Authenticate(_, _ string) (*auth.User, error) {
+func (t testAuther) AuthenticateUser(_, _ string) (*auth.User, error) {
 	return nil, errors.New("not used")
 }
 
diff --git a/web/src/app/Api.js b/web/src/app/Api.js
index d2b8c7e5..c106a280 100644
--- a/web/src/app/Api.js
+++ b/web/src/app/Api.js
@@ -1,11 +1,13 @@
 import {
     fetchLinesIterator,
-    maybeWithBasicAuth,
+    maybeWithBasicAuth, maybeWithBearerAuth,
     topicShortUrl,
     topicUrl,
     topicUrlAuth,
     topicUrlJsonPoll,
-    topicUrlJsonPollWithSince, userAuthUrl,
+    topicUrlJsonPollWithSince,
+    userAccountUrl,
+    userAuthUrl,
     userStatsUrl
 } from "./utils";
 import userManager from "./UserManager";
@@ -144,6 +146,20 @@ class Api {
         console.log(`[Api] Stats`, stats);
         return stats;
     }
+
+    async userAccount(baseUrl, token) {
+        const url = userAccountUrl(baseUrl);
+        console.log(`[Api] Fetching user account ${url}`);
+        const response = await fetch(url, {
+            headers: maybeWithBearerAuth({}, token)
+        });
+        if (response.status !== 200) {
+            throw new Error(`Unexpected server response ${response.status}`);
+        }
+        const account = await response.json();
+        console.log(`[Api] Account`, account);
+        return account;
+    }
 }
 
 const api = new Api();
diff --git a/web/src/app/utils.js b/web/src/app/utils.js
index 24ed825f..36184090 100644
--- a/web/src/app/utils.js
+++ b/web/src/app/utils.js
@@ -20,6 +20,7 @@ export const topicUrlAuth = (baseUrl, topic) => `${topicUrl(baseUrl, topic)}/aut
 export const topicShortUrl = (baseUrl, topic) => shortUrl(topicUrl(baseUrl, topic));
 export const userStatsUrl = (baseUrl) => `${baseUrl}/user/stats`;
 export const userAuthUrl = (baseUrl) => `${baseUrl}/user/auth`;
+export const userAccountUrl = (baseUrl) => `${baseUrl}/user/account`;
 export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
 export const expandUrl = (url) => [`https://${url}`, `http://${url}`];
 export const expandSecureUrl = (url) => `https://${url}`;
@@ -95,7 +96,6 @@ export const unmatchedTags = (tags) => {
     else return tags.filter(tag => !(tag in emojis));
 }
 
-
 export const maybeWithBasicAuth = (headers, user) => {
     if (user) {
         headers['Authorization'] = `Basic ${encodeBase64(`${user.username}:${user.password}`)}`;
@@ -103,6 +103,13 @@ export const maybeWithBasicAuth = (headers, user) => {
     return headers;
 }
 
+export const maybeWithBearerAuth = (headers, token) => {
+    if (token) {
+        headers['Authorization'] = `Bearer ${token}`;
+    }
+    return headers;
+}
+
 export const basicAuth = (username, password) => {
     return `Basic ${encodeBase64(`${username}:${password}`)}`;
 }
diff --git a/web/src/components/App.js b/web/src/components/App.js
index e74aa3db..e69cfead 100644
--- a/web/src/components/App.js
+++ b/web/src/components/App.js
@@ -25,6 +25,10 @@ import "./i18n"; // Translations!
 import {Backdrop, CircularProgress} from "@mui/material";
 import Home from "./Home";
 import Login from "./Login";
+import i18n from "i18next";
+import api from "../app/Api";
+import prefs from "../app/Prefs";
+import session from "../app/Session";
 
 // TODO races when two tabs are open
 // TODO investigate service workers
@@ -81,6 +85,21 @@ const Layout = () => {
     useBackgroundProcesses();
     useEffect(() => updateTitle(newNotificationsCount), [newNotificationsCount]);
 
+    useEffect(() => {
+        (async () => {
+            const account = await api.userAccount("http://localhost:2586", session.token());
+            if (account) {
+                if (account.language) {
+                    await i18n.changeLanguage(account.language);
+                }
+                if (account.notification) {
+                    if (account.notification.sound) {
+                        await prefs.setSound(account.notification.sound);
+                    }
+                }
+            }
+        })();
+    });
     return (
         <Box sx={{display: 'flex'}}>
             <CssBaseline/>
diff --git a/web/src/components/Login.js b/web/src/components/Login.js
index 0e195973..50edd8d7 100644
--- a/web/src/components/Login.js
+++ b/web/src/components/Login.js
@@ -32,7 +32,7 @@ const Login = () => {
             email: data.get('email'),
             password: data.get('password'),
         });
-        const user ={
+        const user = {
             username: data.get('email'),
             password: data.get('password'),
         }