1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-11-22 19:33:27 +01:00

Token stuff

This commit is contained in:
Philipp Heckel 2022-12-03 15:20:59 -05:00
parent d3dfeeccc3
commit d499d20a9c
8 changed files with 194 additions and 64 deletions

View file

@ -6,13 +6,17 @@ import (
"regexp"
)
// Auther is a generic interface to implement password-based authentication and authorization
// Auther is a generic interface to implement password and token based authentication and authorization
type Auther interface {
// Authenticate checks username and password and returns a user if correct. The method
// returns in constant-ish time, regardless of whether the user exists or the password is
// correct or incorrect.
Authenticate(username, password string) (*User, error)
AuthenticateToken(token string) (*User, error)
GenerateToken(user *User) (string, error)
// Authorize returns nil if the given user has access to the given topic using the desired
// permission. The user param may be nil to signal an anonymous user.
Authorize(user *User, topic string, perm Permission) error
@ -56,10 +60,11 @@ type Manager interface {
// User is a struct that represents a user
type User struct {
Name string
Hash string // password hash (bcrypt)
Role Role
Grants []Grant
Name string
Hash string // password hash (bcrypt)
Role Role
Grants []Grant
Language string
}
// Grant is a struct that represents an access control entry to a topic

View file

@ -6,10 +6,12 @@ import (
"fmt"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/util"
"strings"
)
const (
tokenLength = 32
bcryptCost = 10
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
)
@ -67,7 +69,17 @@ const (
INSERT INTO user (id, user, pass, role) VALUES (1, '*', '', 'anonymous') ON CONFLICT (id) DO NOTHING;
COMMIT;
`
selectUserQuery = `SELECT pass, role FROM user WHERE user = ?`
selectUserByNameQuery = `
SELECT user, pass, role, language
FROM user
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT user, pass, role, language
FROM user
JOIN user_token on user.id = user_token.user_id
WHERE token = ?
`
selectTopicPermsQuery = `
SELECT read, write
FROM user_access
@ -90,6 +102,8 @@ const (
deleteAllAccessQuery = `DELETE FROM user_access`
deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
)
// Schema management queries
@ -126,7 +140,7 @@ func NewSQLiteAuth(filename string, defaultRead, defaultWrite bool) (*SQLiteAuth
}, nil
}
// Authenticate checks username and password and returns a user if correct. The method
// AuthenticateUser checks username and password and returns a user if correct. The method
// returns in constant-ish time, regardless of whether the user exists or the password is
// correct or incorrect.
func (a *SQLiteAuth) Authenticate(username, password string) (*User, error) {
@ -145,6 +159,23 @@ func (a *SQLiteAuth) Authenticate(username, password string) (*User, error) {
return user, nil
}
func (a *SQLiteAuth) AuthenticateToken(token string) (*User, error) {
user, err := a.userByToken(token)
if err != nil {
return nil, ErrUnauthenticated
}
return user, nil
}
func (a *SQLiteAuth) GenerateToken(user *User) (string, error) {
token := util.RandomString(tokenLength)
expires := 1 // FIXME
if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires); err != nil {
return "", err
}
return token, nil
}
// Authorize returns nil if the given user has access to the given topic using the desired
// permission. The user param may be nil to signal an anonymous user.
func (a *SQLiteAuth) Authorize(user *User, topic string, perm Permission) error {
@ -255,16 +286,29 @@ func (a *SQLiteAuth) User(username string) (*User, error) {
if username == Everyone {
return a.everyoneUser()
}
rows, err := a.db.Query(selectUserQuery, username)
rows, err := a.db.Query(selectUserByNameQuery, username)
if err != nil {
return nil, err
}
return a.readUser(rows)
}
func (a *SQLiteAuth) userByToken(token string) (*User, error) {
rows, err := a.db.Query(selectUserByTokenQuery, token)
if err != nil {
return nil, err
}
return a.readUser(rows)
}
func (a *SQLiteAuth) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var hash, role string
var username, hash, role string
var language sql.NullString
if !rows.Next() {
return nil, ErrNotFound
}
if err := rows.Scan(&hash, &role); err != nil {
if err := rows.Scan(&username, &hash, &role, &language); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -274,10 +318,11 @@ func (a *SQLiteAuth) User(username string) (*User, error) {
return nil, err
}
return &User{
Name: username,
Hash: hash,
Role: Role(role),
Grants: grants,
Name: username,
Hash: hash,
Role: Role(role),
Grants: grants,
Language: language.String,
}, nil
}

View file

@ -7,6 +7,7 @@ import (
"embed"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@ -320,23 +321,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
} else if r.Method == http.MethodOptions {
return s.ensureWebEnabled(s.handleOptions)(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
return s.limitRequests(s.transformBodyJSON(s.authWrite(s.handlePublish)))(w, r, v)
return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
return s.limitRequests(s.transformMatrixJSON(s.authWrite(s.handlePublishMatrix)))(w, r, v)
return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v)
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
} else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleTopicAuth))(w, r, v)
return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
} else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
return s.ensureWebEnabled(s.handleTopic)(w, r, v)
}
@ -403,8 +404,6 @@ func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visi
return nil
}
var sessions = make(map[string]*auth.User) // token-> user
type tokenAuthResponse struct {
Token string `json:"token"`
}
@ -414,8 +413,10 @@ func (s *Server) handleUserAuth(w http.ResponseWriter, r *http.Request, v *visit
if v.user == nil {
return errHTTPUnauthorized
}
token := util.RandomString(32)
sessions[token] = v.user
token, err := s.auth.GenerateToken(v.user)
if err != nil {
return err
}
w.Header().Set("Content-Type", "text/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
response := &tokenAuthResponse{
@ -432,35 +433,41 @@ type userSubscriptionResponse struct {
Topic string `json:"topic"`
}
type userNotificationSettingsResponse struct {
Sound string `json:"sound"`
MinPriority string `json:"min_priority"`
DeleteAfter int `json:"delete_after"`
}
type userPlanResponse struct {
Id int `json:"id"`
Name string `json:"name"`
}
type userAccountResponse struct {
Username string `json:"username"`
Role string `json:"role,omitempty"`
Language string `json:"language,omitempty"`
Plan struct {
Id int `json:"id"`
Name string `json:"name"`
} `json:"plan,omitempty"`
Notification struct {
Sound string `json:"sound"`
MinPriority string `json:"min_priority"`
DeleteAfter int `json:"delete_after"`
} `json:"notification,omitempty"`
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
Username string `json:"username"`
Role string `json:"role,omitempty"`
Language string `json:"language,omitempty"`
Plan *userPlanResponse `json:"plan,omitempty"`
Notification *userNotificationSettingsResponse `json:"notification,omitempty"`
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
}
func (s *Server) handleUserAccount(w http.ResponseWriter, r *http.Request, v *visitor) error {
w.Header().Set("Content-Type", "text/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
var response *userAccountResponse
response := &userAccountResponse{}
if v.user != nil {
response = &userAccountResponse{
Username: v.user.Name,
Role: string(v.user.Role),
Language: "en_US",
response.Username = v.user.Name
response.Role = string(v.user.Role)
response.Language = v.user.Language
response.Notification = &userNotificationSettingsResponse{
Sound: "dadum",
}
} else {
response = &userAccountResponse{
Username: "anonymous",
Username: auth.Everyone,
Role: string(auth.RoleAnonymous),
}
}
if err := json.NewEncoder(w).Encode(response); err != nil {
@ -1453,15 +1460,15 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
}
}
func (s *Server) authWrite(next handleFunc) handleFunc {
return s.withAuth(next, auth.PermissionWrite)
func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
return s.autorizeTopic(next, auth.PermissionWrite)
}
func (s *Server) authRead(next handleFunc) handleFunc {
return s.withAuth(next, auth.PermissionRead)
func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
return s.autorizeTopic(next, auth.PermissionRead)
}
func (s *Server) withAuth(next handleFunc, perm auth.Permission) handleFunc {
func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.auth == nil {
return next(w, r, v)
@ -1508,20 +1515,51 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
visitorID := fmt.Sprintf("ip:%s", ip.String())
var user *auth.User // may stay nil if no auth header!
username, password, ok := extractUserPass(r)
if ok {
if user, err = s.auth.Authenticate(username, password); err != nil {
log.Debug("authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
} else {
visitorID = fmt.Sprintf("user:%s", user.Name)
}
if user, err = s.authenticate(r); err != nil {
log.Debug("authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
}
if user != nil {
visitorID = fmt.Sprintf("user:%s", user.Name)
}
v = s.visitorFromID(visitorID, ip, user)
v.user = user // Update user -- FIXME this is ugly, do "newVisitorFromUser" instead
return v, err // Always return visitor, even when error occurs!
}
func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) {
value := r.Header.Get("Authorization")
queryParam := readQueryParam(r, "authorization", "auth")
if queryParam != "" {
a, err := base64.RawURLEncoding.DecodeString(queryParam)
if err != nil {
return nil, err
}
value = string(a)
}
if value == "" {
return nil, nil
}
if strings.HasPrefix(value, "Bearer") {
return s.authenticateBearerAuth(value)
}
return s.authenticateBasicAuth(r, value)
}
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *auth.User, err error) {
r.Header.Set("Authorization", value)
username, password, ok := r.BasicAuth()
if !ok {
return nil, errors.New("invalid basic auth")
}
return s.auth.Authenticate(username, password)
}
func (s *Server) authenticateBearerAuth(value string) (user *auth.User, err error) {
token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
return s.auth.AuthenticateToken(token)
}
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor {
s.mu.Lock()
defer s.mu.Unlock()

View file

@ -18,7 +18,7 @@ type testAuther struct {
Allow bool
}
func (t testAuther) Authenticate(_, _ string) (*auth.User, error) {
func (t testAuther) AuthenticateUser(_, _ string) (*auth.User, error) {
return nil, errors.New("not used")
}

View file

@ -1,11 +1,13 @@
import {
fetchLinesIterator,
maybeWithBasicAuth,
maybeWithBasicAuth, maybeWithBearerAuth,
topicShortUrl,
topicUrl,
topicUrlAuth,
topicUrlJsonPoll,
topicUrlJsonPollWithSince, userAuthUrl,
topicUrlJsonPollWithSince,
userAccountUrl,
userAuthUrl,
userStatsUrl
} from "./utils";
import userManager from "./UserManager";
@ -144,6 +146,20 @@ class Api {
console.log(`[Api] Stats`, stats);
return stats;
}
async userAccount(baseUrl, token) {
const url = userAccountUrl(baseUrl);
console.log(`[Api] Fetching user account ${url}`);
const response = await fetch(url, {
headers: maybeWithBearerAuth({}, token)
});
if (response.status !== 200) {
throw new Error(`Unexpected server response ${response.status}`);
}
const account = await response.json();
console.log(`[Api] Account`, account);
return account;
}
}
const api = new Api();

View file

@ -20,6 +20,7 @@ export const topicUrlAuth = (baseUrl, topic) => `${topicUrl(baseUrl, topic)}/aut
export const topicShortUrl = (baseUrl, topic) => shortUrl(topicUrl(baseUrl, topic));
export const userStatsUrl = (baseUrl) => `${baseUrl}/user/stats`;
export const userAuthUrl = (baseUrl) => `${baseUrl}/user/auth`;
export const userAccountUrl = (baseUrl) => `${baseUrl}/user/account`;
export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
export const expandUrl = (url) => [`https://${url}`, `http://${url}`];
export const expandSecureUrl = (url) => `https://${url}`;
@ -95,7 +96,6 @@ export const unmatchedTags = (tags) => {
else return tags.filter(tag => !(tag in emojis));
}
export const maybeWithBasicAuth = (headers, user) => {
if (user) {
headers['Authorization'] = `Basic ${encodeBase64(`${user.username}:${user.password}`)}`;
@ -103,6 +103,13 @@ export const maybeWithBasicAuth = (headers, user) => {
return headers;
}
export const maybeWithBearerAuth = (headers, token) => {
if (token) {
headers['Authorization'] = `Bearer ${token}`;
}
return headers;
}
export const basicAuth = (username, password) => {
return `Basic ${encodeBase64(`${username}:${password}`)}`;
}

View file

@ -25,6 +25,10 @@ import "./i18n"; // Translations!
import {Backdrop, CircularProgress} from "@mui/material";
import Home from "./Home";
import Login from "./Login";
import i18n from "i18next";
import api from "../app/Api";
import prefs from "../app/Prefs";
import session from "../app/Session";
// TODO races when two tabs are open
// TODO investigate service workers
@ -81,6 +85,21 @@ const Layout = () => {
useBackgroundProcesses();
useEffect(() => updateTitle(newNotificationsCount), [newNotificationsCount]);
useEffect(() => {
(async () => {
const account = await api.userAccount("http://localhost:2586", session.token());
if (account) {
if (account.language) {
await i18n.changeLanguage(account.language);
}
if (account.notification) {
if (account.notification.sound) {
await prefs.setSound(account.notification.sound);
}
}
}
})();
});
return (
<Box sx={{display: 'flex'}}>
<CssBaseline/>

View file

@ -32,7 +32,7 @@ const Login = () => {
email: data.get('email'),
password: data.get('password'),
});
const user ={
const user = {
username: data.get('email'),
password: data.get('password'),
}