1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-05-30 10:25:40 +02:00

Token stuff

This commit is contained in:
Philipp Heckel 2022-12-03 15:20:59 -05:00
parent d3dfeeccc3
commit d499d20a9c
8 changed files with 194 additions and 64 deletions
server

View file

@ -7,6 +7,7 @@ import (
"embed"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@ -320,23 +321,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
} else if r.Method == http.MethodOptions {
return s.ensureWebEnabled(s.handleOptions)(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
return s.limitRequests(s.transformBodyJSON(s.authWrite(s.handlePublish)))(w, r, v)
return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
return s.limitRequests(s.transformMatrixJSON(s.authWrite(s.handlePublishMatrix)))(w, r, v)
return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v)
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
} else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
return s.limitRequests(s.authRead(s.handleTopicAuth))(w, r, v)
return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
} else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
return s.ensureWebEnabled(s.handleTopic)(w, r, v)
}
@ -403,8 +404,6 @@ func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visi
return nil
}
var sessions = make(map[string]*auth.User) // token-> user
type tokenAuthResponse struct {
Token string `json:"token"`
}
@ -414,8 +413,10 @@ func (s *Server) handleUserAuth(w http.ResponseWriter, r *http.Request, v *visit
if v.user == nil {
return errHTTPUnauthorized
}
token := util.RandomString(32)
sessions[token] = v.user
token, err := s.auth.GenerateToken(v.user)
if err != nil {
return err
}
w.Header().Set("Content-Type", "text/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
response := &tokenAuthResponse{
@ -432,35 +433,41 @@ type userSubscriptionResponse struct {
Topic string `json:"topic"`
}
type userNotificationSettingsResponse struct {
Sound string `json:"sound"`
MinPriority string `json:"min_priority"`
DeleteAfter int `json:"delete_after"`
}
type userPlanResponse struct {
Id int `json:"id"`
Name string `json:"name"`
}
type userAccountResponse struct {
Username string `json:"username"`
Role string `json:"role,omitempty"`
Language string `json:"language,omitempty"`
Plan struct {
Id int `json:"id"`
Name string `json:"name"`
} `json:"plan,omitempty"`
Notification struct {
Sound string `json:"sound"`
MinPriority string `json:"min_priority"`
DeleteAfter int `json:"delete_after"`
} `json:"notification,omitempty"`
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
Username string `json:"username"`
Role string `json:"role,omitempty"`
Language string `json:"language,omitempty"`
Plan *userPlanResponse `json:"plan,omitempty"`
Notification *userNotificationSettingsResponse `json:"notification,omitempty"`
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
}
func (s *Server) handleUserAccount(w http.ResponseWriter, r *http.Request, v *visitor) error {
w.Header().Set("Content-Type", "text/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
var response *userAccountResponse
response := &userAccountResponse{}
if v.user != nil {
response = &userAccountResponse{
Username: v.user.Name,
Role: string(v.user.Role),
Language: "en_US",
response.Username = v.user.Name
response.Role = string(v.user.Role)
response.Language = v.user.Language
response.Notification = &userNotificationSettingsResponse{
Sound: "dadum",
}
} else {
response = &userAccountResponse{
Username: "anonymous",
Username: auth.Everyone,
Role: string(auth.RoleAnonymous),
}
}
if err := json.NewEncoder(w).Encode(response); err != nil {
@ -1453,15 +1460,15 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
}
}
func (s *Server) authWrite(next handleFunc) handleFunc {
return s.withAuth(next, auth.PermissionWrite)
func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
return s.autorizeTopic(next, auth.PermissionWrite)
}
func (s *Server) authRead(next handleFunc) handleFunc {
return s.withAuth(next, auth.PermissionRead)
func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
return s.autorizeTopic(next, auth.PermissionRead)
}
func (s *Server) withAuth(next handleFunc, perm auth.Permission) handleFunc {
func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.auth == nil {
return next(w, r, v)
@ -1508,20 +1515,51 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
visitorID := fmt.Sprintf("ip:%s", ip.String())
var user *auth.User // may stay nil if no auth header!
username, password, ok := extractUserPass(r)
if ok {
if user, err = s.auth.Authenticate(username, password); err != nil {
log.Debug("authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
} else {
visitorID = fmt.Sprintf("user:%s", user.Name)
}
if user, err = s.authenticate(r); err != nil {
log.Debug("authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
}
if user != nil {
visitorID = fmt.Sprintf("user:%s", user.Name)
}
v = s.visitorFromID(visitorID, ip, user)
v.user = user // Update user -- FIXME this is ugly, do "newVisitorFromUser" instead
return v, err // Always return visitor, even when error occurs!
}
func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) {
value := r.Header.Get("Authorization")
queryParam := readQueryParam(r, "authorization", "auth")
if queryParam != "" {
a, err := base64.RawURLEncoding.DecodeString(queryParam)
if err != nil {
return nil, err
}
value = string(a)
}
if value == "" {
return nil, nil
}
if strings.HasPrefix(value, "Bearer") {
return s.authenticateBearerAuth(value)
}
return s.authenticateBasicAuth(r, value)
}
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *auth.User, err error) {
r.Header.Set("Authorization", value)
username, password, ok := r.BasicAuth()
if !ok {
return nil, errors.New("invalid basic auth")
}
return s.auth.Authenticate(username, password)
}
func (s *Server) authenticateBearerAuth(value string) (user *auth.User, err error) {
token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
return s.auth.AuthenticateToken(token)
}
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor {
s.mu.Lock()
defer s.mu.Unlock()