1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-06-12 15:43:23 +02:00

WIP: Auth in 80 lines of code :-)

This commit is contained in:
Philipp Heckel 2022-01-21 22:22:27 -05:00
parent aab705f4a4
commit 2181227a6e
3 changed files with 83 additions and 12 deletions
server

View file

@ -46,6 +46,7 @@ type Server struct {
firebase subscriber
mailer mailer
messages int64
auther auther
cache cache
fileCache *fileCache
closeChan chan bool
@ -57,6 +58,9 @@ type indexPage struct {
CacheDuration time.Duration
}
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
var (
topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /!
topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
@ -144,6 +148,7 @@ func New(conf *Config) (*Server, error) {
firebase: firebaseSubscriber,
mailer: mailer,
topics: topics,
auther: &memAuther{},
visitors: make(map[string]*visitor),
}, nil
}
@ -312,6 +317,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
v := s.visitor(r)
if r.Method == http.MethodGet && r.URL.Path == "/" {
return s.handleHome(w, r)
} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
@ -323,23 +329,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
return s.handleDocs(w, r)
} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
return s.withRateLimit(w, r, s.handleFile)
return s.limitRequests(s.handleFile)(w, r, v)
} else if r.Method == http.MethodOptions {
return s.handleOptions(w, r)
} else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) {
return s.handleTopic(w, r)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handlePublish)
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handlePublish)
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeJSON)
return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeSSE)
return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v)
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeRaw)
return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeWS)
return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
}
return errHTTPNotFound
}
@ -1094,12 +1100,45 @@ func (s *Server) sendDelayedMessages() error {
return nil
}
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
v := s.visitor(r)
if err := v.RequestAllowed(); err != nil {
return errHTTPTooManyRequestsLimitRequests
func (s *Server) limitRequests(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if err := v.RequestAllowed(); err != nil {
return errHTTPTooManyRequestsLimitRequests
}
return next(w, r, v)
}
}
func (s *Server) authWrite(next handleFunc) handleFunc {
return s.withAuth(next, permWrite)
}
func (s *Server) authRead(next handleFunc) handleFunc {
return s.withAuth(next, permRead)
}
func (s *Server) withAuth(next handleFunc, perm int) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.auther == nil {
return next(w, r, v)
}
t, err := s.topicFromPath(r.URL.Path)
if err != nil {
return err
}
user, pass, ok := r.BasicAuth()
if ok {
if !s.auther.Authenticate(user, pass) {
return errHTTPUnauthorized
}
} else {
user = "" // Just in case
}
if !s.auther.Authorize(user, t.ID, perm) {
return errHTTPUnauthorized
}
return next(w, r, v)
}
return handler(w, r, v)
}
// visitor creates or retrieves a rate.Limiter for the given visitor.