mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-05-18 13:04:34 +02:00
WIPWIPWIP
This commit is contained in:
parent
84dca41008
commit
2772a38dae
16 changed files with 644 additions and 66 deletions
server
154
server/server.go
154
server/server.go
|
@ -43,7 +43,7 @@ type Server struct {
|
|||
smtpServerBackend *smtpBackend
|
||||
smtpSender mailer
|
||||
topics map[string]*topic
|
||||
visitors map[netip.Addr]*visitor
|
||||
visitors map[string]*visitor // ip:<ip> or user:<user>
|
||||
firebaseClient *firebaseClient
|
||||
messages int64
|
||||
auth auth.Auther
|
||||
|
@ -69,7 +69,9 @@ var (
|
|||
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
|
||||
|
||||
webConfigPath = "/config.js"
|
||||
userStatsPath = "/user/stats"
|
||||
userStatsPath = "/user/stats" // FIXME get rid of this in favor of /user/account
|
||||
userAuthPath = "/user/auth"
|
||||
userAccountPath = "/user/account"
|
||||
matrixPushPath = "/_matrix/push/v1/notify"
|
||||
staticRegex = regexp.MustCompile(`^/static/.+`)
|
||||
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
|
||||
|
@ -151,7 +153,7 @@ func New(conf *Config) (*Server, error) {
|
|||
smtpSender: mailer,
|
||||
topics: topics,
|
||||
auth: auther,
|
||||
visitors: make(map[netip.Addr]*visitor),
|
||||
visitors: make(map[string]*visitor),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -255,12 +257,15 @@ func (s *Server) Stop() {
|
|||
}
|
||||
|
||||
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||
v := s.visitor(r)
|
||||
log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
|
||||
if log.IsTrace() {
|
||||
log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
|
||||
v, err := s.visitor(r) // Note: Always returns v, even when error is returned
|
||||
if err == nil {
|
||||
log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
|
||||
if log.IsTrace() {
|
||||
log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
|
||||
}
|
||||
err = s.handleInternal(w, r, v)
|
||||
}
|
||||
if err := s.handleInternal(w, r, v); err != nil {
|
||||
if err != nil {
|
||||
if websocket.IsWebSocketUpgrade(r) {
|
||||
isNormalError := strings.Contains(err.Error(), "i/o timeout")
|
||||
if isNormalError {
|
||||
|
@ -300,6 +305,10 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
|||
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == userStatsPath {
|
||||
return s.handleUserStats(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == userAuthPath {
|
||||
return s.handleUserAuth(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == userAccountPath {
|
||||
return s.handleUserAccount(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
|
||||
return s.handleMatrixDiscovery(w)
|
||||
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
|
||||
|
@ -394,6 +403,72 @@ func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visi
|
|||
return nil
|
||||
}
|
||||
|
||||
var sessions = make(map[string]*auth.User) // token-> user
|
||||
|
||||
type tokenAuthResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
func (s *Server) handleUserAuth(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
// TODO rate limit
|
||||
if v.user == nil {
|
||||
return errHTTPUnauthorized
|
||||
}
|
||||
token := util.RandomString(32)
|
||||
sessions[token] = v.user
|
||||
w.Header().Set("Content-Type", "text/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
response := &tokenAuthResponse{
|
||||
Token: token,
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type userSubscriptionResponse struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
Topic string `json:"topic"`
|
||||
}
|
||||
|
||||
type userAccountResponse struct {
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Plan struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"plan,omitempty"`
|
||||
Notification struct {
|
||||
Sound string `json:"sound"`
|
||||
MinPriority string `json:"min_priority"`
|
||||
DeleteAfter int `json:"delete_after"`
|
||||
} `json:"notification,omitempty"`
|
||||
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) handleUserAccount(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
w.Header().Set("Content-Type", "text/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
var response *userAccountResponse
|
||||
if v.user != nil {
|
||||
response = &userAccountResponse{
|
||||
Username: v.user.Name,
|
||||
Role: string(v.user.Role),
|
||||
Language: "en_US",
|
||||
}
|
||||
} else {
|
||||
response = &userAccountResponse{
|
||||
Username: "anonymous",
|
||||
}
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
r.URL.Path = webSiteDir + r.URL.Path
|
||||
util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
|
||||
|
@ -1221,7 +1296,7 @@ func (s *Server) runFirebaseKeepaliver() {
|
|||
if s.firebaseClient == nil {
|
||||
return
|
||||
}
|
||||
v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0
|
||||
v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified(), nil) // Background process, not a real visitor, uses IP 0.0.0.0
|
||||
for {
|
||||
select {
|
||||
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||
|
@ -1253,7 +1328,7 @@ func (s *Server) sendDelayedMessages() error {
|
|||
return err
|
||||
}
|
||||
for _, m := range messages {
|
||||
v := s.visitorFromIP(m.Sender)
|
||||
v := s.visitorFromID(fmt.Sprintf("ip:%s", m.Sender.String()), m.Sender, nil) // FIXME: This is wrong wrong wrong
|
||||
if err := s.sendDelayedMessage(v, m); err != nil {
|
||||
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
|
||||
}
|
||||
|
@ -1395,16 +1470,8 @@ func (s *Server) withAuth(next handleFunc, perm auth.Permission) handleFunc {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var user *auth.User // may stay nil if no auth header!
|
||||
username, password, ok := extractUserPass(r)
|
||||
if ok {
|
||||
if user, err = s.auth.Authenticate(username, password); err != nil {
|
||||
log.Info("authentication failed: %s", err.Error())
|
||||
return errHTTPUnauthorized
|
||||
}
|
||||
}
|
||||
for _, t := range topics {
|
||||
if err := s.auth.Authorize(user, t.ID, perm); err != nil {
|
||||
if err := s.auth.Authorize(v.user, t.ID, perm); err != nil {
|
||||
log.Info("unauthorized: %s", err.Error())
|
||||
return errHTTPForbidden
|
||||
}
|
||||
|
@ -1435,8 +1502,39 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
|
|||
}
|
||||
|
||||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||
// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
||||
func (s *Server) visitor(r *http.Request) *visitor {
|
||||
// Note that this function will always return a visitor, even if an error occurs.
|
||||
func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
||||
ip := s.extractIPAddress(r)
|
||||
visitorID := fmt.Sprintf("ip:%s", ip.String())
|
||||
|
||||
var user *auth.User // may stay nil if no auth header!
|
||||
username, password, ok := extractUserPass(r)
|
||||
if ok {
|
||||
if user, err = s.auth.Authenticate(username, password); err != nil {
|
||||
log.Debug("authentication failed: %s", err.Error())
|
||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||
} else {
|
||||
visitorID = fmt.Sprintf("user:%s", user.Name)
|
||||
}
|
||||
}
|
||||
v = s.visitorFromID(visitorID, ip, user)
|
||||
v.user = user // Update user -- FIXME this is ugly, do "newVisitorFromUser" instead
|
||||
return v, err // Always return visitor, even when error occurs!
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, exists := s.visitors[visitorID]
|
||||
if !exists {
|
||||
s.visitors[visitorID] = newVisitor(s.config, s.messageCache, ip, user)
|
||||
return s.visitors[visitorID]
|
||||
}
|
||||
v.Keepalive()
|
||||
return v
|
||||
}
|
||||
|
||||
func (s *Server) extractIPAddress(r *http.Request) netip.Addr {
|
||||
remoteAddr := r.RemoteAddr
|
||||
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
||||
ip := addrPort.Addr()
|
||||
|
@ -1461,17 +1559,5 @@ func (s *Server) visitor(r *http.Request) *visitor {
|
|||
ip = realIP
|
||||
}
|
||||
}
|
||||
return s.visitorFromIP(ip)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, exists := s.visitors[ip]
|
||||
if !exists {
|
||||
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)
|
||||
return s.visitors[ip]
|
||||
}
|
||||
v.Keepalive()
|
||||
return v
|
||||
return ip
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue