1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2025-05-18 13:04:34 +02:00

WIPWIPWIP

This commit is contained in:
Philipp Heckel 2022-12-02 15:37:48 -05:00
parent 84dca41008
commit 2772a38dae
16 changed files with 644 additions and 66 deletions
server

View file

@ -43,7 +43,7 @@ type Server struct {
smtpServerBackend *smtpBackend
smtpSender mailer
topics map[string]*topic
visitors map[netip.Addr]*visitor
visitors map[string]*visitor // ip:<ip> or user:<user>
firebaseClient *firebaseClient
messages int64
auth auth.Auther
@ -69,7 +69,9 @@ var (
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
webConfigPath = "/config.js"
userStatsPath = "/user/stats"
userStatsPath = "/user/stats" // FIXME get rid of this in favor of /user/account
userAuthPath = "/user/auth"
userAccountPath = "/user/account"
matrixPushPath = "/_matrix/push/v1/notify"
staticRegex = regexp.MustCompile(`^/static/.+`)
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
@ -151,7 +153,7 @@ func New(conf *Config) (*Server, error) {
smtpSender: mailer,
topics: topics,
auth: auther,
visitors: make(map[netip.Addr]*visitor),
visitors: make(map[string]*visitor),
}, nil
}
@ -255,12 +257,15 @@ func (s *Server) Stop() {
}
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
v := s.visitor(r)
log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
if log.IsTrace() {
log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
v, err := s.visitor(r) // Note: Always returns v, even when error is returned
if err == nil {
log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
if log.IsTrace() {
log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
}
err = s.handleInternal(w, r, v)
}
if err := s.handleInternal(w, r, v); err != nil {
if err != nil {
if websocket.IsWebSocketUpgrade(r) {
isNormalError := strings.Contains(err.Error(), "i/o timeout")
if isNormalError {
@ -300,6 +305,10 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == userStatsPath {
return s.handleUserStats(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == userAuthPath {
return s.handleUserAuth(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == userAccountPath {
return s.handleUserAccount(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
@ -394,6 +403,72 @@ func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visi
return nil
}
var sessions = make(map[string]*auth.User) // token-> user
type tokenAuthResponse struct {
Token string `json:"token"`
}
func (s *Server) handleUserAuth(w http.ResponseWriter, r *http.Request, v *visitor) error {
// TODO rate limit
if v.user == nil {
return errHTTPUnauthorized
}
token := util.RandomString(32)
sessions[token] = v.user
w.Header().Set("Content-Type", "text/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
response := &tokenAuthResponse{
Token: token,
}
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
type userSubscriptionResponse struct {
BaseURL string `json:"base_url"`
Topic string `json:"topic"`
}
type userAccountResponse struct {
Username string `json:"username"`
Role string `json:"role,omitempty"`
Language string `json:"language,omitempty"`
Plan struct {
Id int `json:"id"`
Name string `json:"name"`
} `json:"plan,omitempty"`
Notification struct {
Sound string `json:"sound"`
MinPriority string `json:"min_priority"`
DeleteAfter int `json:"delete_after"`
} `json:"notification,omitempty"`
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
}
func (s *Server) handleUserAccount(w http.ResponseWriter, r *http.Request, v *visitor) error {
w.Header().Set("Content-Type", "text/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
var response *userAccountResponse
if v.user != nil {
response = &userAccountResponse{
Username: v.user.Name,
Role: string(v.user.Role),
Language: "en_US",
}
} else {
response = &userAccountResponse{
Username: "anonymous",
}
}
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
r.URL.Path = webSiteDir + r.URL.Path
util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
@ -1221,7 +1296,7 @@ func (s *Server) runFirebaseKeepaliver() {
if s.firebaseClient == nil {
return
}
v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0
v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified(), nil) // Background process, not a real visitor, uses IP 0.0.0.0
for {
select {
case <-time.After(s.config.FirebaseKeepaliveInterval):
@ -1253,7 +1328,7 @@ func (s *Server) sendDelayedMessages() error {
return err
}
for _, m := range messages {
v := s.visitorFromIP(m.Sender)
v := s.visitorFromID(fmt.Sprintf("ip:%s", m.Sender.String()), m.Sender, nil) // FIXME: This is wrong wrong wrong
if err := s.sendDelayedMessage(v, m); err != nil {
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
}
@ -1395,16 +1470,8 @@ func (s *Server) withAuth(next handleFunc, perm auth.Permission) handleFunc {
if err != nil {
return err
}
var user *auth.User // may stay nil if no auth header!
username, password, ok := extractUserPass(r)
if ok {
if user, err = s.auth.Authenticate(username, password); err != nil {
log.Info("authentication failed: %s", err.Error())
return errHTTPUnauthorized
}
}
for _, t := range topics {
if err := s.auth.Authorize(user, t.ID, perm); err != nil {
if err := s.auth.Authorize(v.user, t.ID, perm); err != nil {
log.Info("unauthorized: %s", err.Error())
return errHTTPForbidden
}
@ -1435,8 +1502,39 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
}
// visitor creates or retrieves a rate.Limiter for the given visitor.
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
func (s *Server) visitor(r *http.Request) *visitor {
// Note that this function will always return a visitor, even if an error occurs.
func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
ip := s.extractIPAddress(r)
visitorID := fmt.Sprintf("ip:%s", ip.String())
var user *auth.User // may stay nil if no auth header!
username, password, ok := extractUserPass(r)
if ok {
if user, err = s.auth.Authenticate(username, password); err != nil {
log.Debug("authentication failed: %s", err.Error())
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
} else {
visitorID = fmt.Sprintf("user:%s", user.Name)
}
}
v = s.visitorFromID(visitorID, ip, user)
v.user = user // Update user -- FIXME this is ugly, do "newVisitorFromUser" instead
return v, err // Always return visitor, even when error occurs!
}
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
v, exists := s.visitors[visitorID]
if !exists {
s.visitors[visitorID] = newVisitor(s.config, s.messageCache, ip, user)
return s.visitors[visitorID]
}
v.Keepalive()
return v
}
func (s *Server) extractIPAddress(r *http.Request) netip.Addr {
remoteAddr := r.RemoteAddr
addrPort, err := netip.ParseAddrPort(remoteAddr)
ip := addrPort.Addr()
@ -1461,17 +1559,5 @@ func (s *Server) visitor(r *http.Request) *visitor {
ip = realIP
}
}
return s.visitorFromIP(ip)
}
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
s.mu.Lock()
defer s.mu.Unlock()
v, exists := s.visitors[ip]
if !exists {
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)
return s.visitors[ip]
}
v.Keepalive()
return v
return ip
}