mirror of
https://github.com/binwiederhier/ntfy.git
synced 2025-05-30 10:25:40 +02:00
Token stuff
This commit is contained in:
parent
d3dfeeccc3
commit
d499d20a9c
8 changed files with 194 additions and 64 deletions
server
128
server/server.go
128
server/server.go
|
@ -7,6 +7,7 @@ import (
|
|||
"embed"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -320,23 +321,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
|||
} else if r.Method == http.MethodOptions {
|
||||
return s.ensureWebEnabled(s.handleOptions)(w, r, v)
|
||||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
|
||||
return s.limitRequests(s.transformBodyJSON(s.authWrite(s.handlePublish)))(w, r, v)
|
||||
return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
|
||||
return s.limitRequests(s.transformMatrixJSON(s.authWrite(s.handlePublishMatrix)))(w, r, v)
|
||||
return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
|
||||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authRead(s.handleTopicAuth))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
|
||||
return s.ensureWebEnabled(s.handleTopic)(w, r, v)
|
||||
}
|
||||
|
@ -403,8 +404,6 @@ func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visi
|
|||
return nil
|
||||
}
|
||||
|
||||
var sessions = make(map[string]*auth.User) // token-> user
|
||||
|
||||
type tokenAuthResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
@ -414,8 +413,10 @@ func (s *Server) handleUserAuth(w http.ResponseWriter, r *http.Request, v *visit
|
|||
if v.user == nil {
|
||||
return errHTTPUnauthorized
|
||||
}
|
||||
token := util.RandomString(32)
|
||||
sessions[token] = v.user
|
||||
token, err := s.auth.GenerateToken(v.user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
response := &tokenAuthResponse{
|
||||
|
@ -432,35 +433,41 @@ type userSubscriptionResponse struct {
|
|||
Topic string `json:"topic"`
|
||||
}
|
||||
|
||||
type userNotificationSettingsResponse struct {
|
||||
Sound string `json:"sound"`
|
||||
MinPriority string `json:"min_priority"`
|
||||
DeleteAfter int `json:"delete_after"`
|
||||
}
|
||||
|
||||
type userPlanResponse struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type userAccountResponse struct {
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Plan struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"plan,omitempty"`
|
||||
Notification struct {
|
||||
Sound string `json:"sound"`
|
||||
MinPriority string `json:"min_priority"`
|
||||
DeleteAfter int `json:"delete_after"`
|
||||
} `json:"notification,omitempty"`
|
||||
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Plan *userPlanResponse `json:"plan,omitempty"`
|
||||
Notification *userNotificationSettingsResponse `json:"notification,omitempty"`
|
||||
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) handleUserAccount(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
w.Header().Set("Content-Type", "text/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
var response *userAccountResponse
|
||||
response := &userAccountResponse{}
|
||||
if v.user != nil {
|
||||
response = &userAccountResponse{
|
||||
Username: v.user.Name,
|
||||
Role: string(v.user.Role),
|
||||
Language: "en_US",
|
||||
response.Username = v.user.Name
|
||||
response.Role = string(v.user.Role)
|
||||
response.Language = v.user.Language
|
||||
response.Notification = &userNotificationSettingsResponse{
|
||||
Sound: "dadum",
|
||||
}
|
||||
} else {
|
||||
response = &userAccountResponse{
|
||||
Username: "anonymous",
|
||||
Username: auth.Everyone,
|
||||
Role: string(auth.RoleAnonymous),
|
||||
}
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
|
@ -1453,15 +1460,15 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) authWrite(next handleFunc) handleFunc {
|
||||
return s.withAuth(next, auth.PermissionWrite)
|
||||
func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
|
||||
return s.autorizeTopic(next, auth.PermissionWrite)
|
||||
}
|
||||
|
||||
func (s *Server) authRead(next handleFunc) handleFunc {
|
||||
return s.withAuth(next, auth.PermissionRead)
|
||||
func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
|
||||
return s.autorizeTopic(next, auth.PermissionRead)
|
||||
}
|
||||
|
||||
func (s *Server) withAuth(next handleFunc, perm auth.Permission) handleFunc {
|
||||
func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.auth == nil {
|
||||
return next(w, r, v)
|
||||
|
@ -1508,20 +1515,51 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
|||
visitorID := fmt.Sprintf("ip:%s", ip.String())
|
||||
|
||||
var user *auth.User // may stay nil if no auth header!
|
||||
username, password, ok := extractUserPass(r)
|
||||
if ok {
|
||||
if user, err = s.auth.Authenticate(username, password); err != nil {
|
||||
log.Debug("authentication failed: %s", err.Error())
|
||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||
} else {
|
||||
visitorID = fmt.Sprintf("user:%s", user.Name)
|
||||
}
|
||||
if user, err = s.authenticate(r); err != nil {
|
||||
log.Debug("authentication failed: %s", err.Error())
|
||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||
}
|
||||
if user != nil {
|
||||
visitorID = fmt.Sprintf("user:%s", user.Name)
|
||||
}
|
||||
v = s.visitorFromID(visitorID, ip, user)
|
||||
v.user = user // Update user -- FIXME this is ugly, do "newVisitorFromUser" instead
|
||||
return v, err // Always return visitor, even when error occurs!
|
||||
}
|
||||
|
||||
func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) {
|
||||
value := r.Header.Get("Authorization")
|
||||
queryParam := readQueryParam(r, "authorization", "auth")
|
||||
if queryParam != "" {
|
||||
a, err := base64.RawURLEncoding.DecodeString(queryParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value = string(a)
|
||||
}
|
||||
if value == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if strings.HasPrefix(value, "Bearer") {
|
||||
return s.authenticateBearerAuth(value)
|
||||
}
|
||||
return s.authenticateBasicAuth(r, value)
|
||||
}
|
||||
|
||||
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *auth.User, err error) {
|
||||
r.Header.Set("Authorization", value)
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return nil, errors.New("invalid basic auth")
|
||||
}
|
||||
return s.auth.Authenticate(username, password)
|
||||
}
|
||||
|
||||
func (s *Server) authenticateBearerAuth(value string) (user *auth.User, err error) {
|
||||
token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
|
||||
return s.auth.AuthenticateToken(token)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue