1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-12-22 17:52:30 +01:00

Add tests for users, slightly change API a bit

This commit is contained in:
binwiederhier 2023-05-15 10:42:24 -04:00
parent 4f4165f46f
commit f14f0aaa26
5 changed files with 181 additions and 16 deletions

View file

@ -82,8 +82,8 @@ var (
apiHealthPath = "/v1/health"
apiStatsPath = "/v1/stats"
apiTiersPath = "/v1/tiers"
apiUserPath = "/v1/user"
apiAccessPath = "/v1/access"
apiUsersPath = "/v1/users"
apiUsersAccessPath = "/v1/users/access"
apiAccountPath = "/v1/account"
apiAccountTokenPath = "/v1/account/token"
apiAccountPasswordPath = "/v1/account/password"
@ -413,13 +413,15 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.handleHealth(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
} else if r.Method == http.MethodPut && r.URL.Path == apiUserPath {
return s.ensureAdmin(s.handleUserAdd)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == apiUserPath {
return s.ensureAdmin(s.handleUserDelete)(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == apiAccessPath {
} else if r.Method == http.MethodGet && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersGet)(w, r, v)
} else if r.Method == http.MethodPut && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersAdd)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersDelete)(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == apiUsersAccessPath {
return s.ensureAdmin(s.handleAccessAllow)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == apiAccessPath {
} else if r.Method == http.MethodDelete && r.URL.Path == apiUsersAccessPath {
return s.ensureAdmin(s.handleAccessReset)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
@ -1456,7 +1458,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
return s.topicFromID(parts[1])
}
// topicFromID returns the topic from a root path (e.g. /mytopic,mytopic2), creating it if it doesn't exist.
// topicsFromPath returns the topic from a root path (e.g. /mytopic,mytopic2), creating it if it doesn't exist.
func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
parts := strings.Split(path, "/")
if len(parts) < 2 {

View file

@ -5,7 +5,39 @@ import (
"net/http"
)
func (s *Server) handleUserAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
users, err := s.userManager.Users()
if err != nil {
return err
}
grants, err := s.userManager.AllGrants()
if err != nil {
return err
}
usersResponse := make([]*apiUserResponse, len(users))
for i, u := range users {
tier := ""
if u.Tier != nil {
tier = u.Tier.Code
}
userGrants := make([]*apiUserGrantResponse, len(grants[u.ID]))
for i, g := range grants[u.ID] {
userGrants[i] = &apiUserGrantResponse{
Topic: g.TopicPattern,
Permission: g.Allow.String(),
}
}
usersResponse[i] = &apiUserResponse{
Username: u.Name,
Role: string(u.Role),
Tier: tier,
Grants: userGrants,
}
}
return s.writeJSON(w, usersResponse)
}
func (s *Server) handleUsersAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiUserAddRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
@ -38,7 +70,7 @@ func (s *Server) handleUserAdd(w http.ResponseWriter, r *http.Request, v *visito
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleUserDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
func (s *Server) handleUsersDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiUserDeleteRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
@ -65,6 +97,12 @@ func (s *Server) handleAccessAllow(w http.ResponseWriter, r *http.Request, v *vi
if err != nil {
return err
}
_, err = s.userManager.User(req.Username)
if err == user.ErrUserNotFound {
return errHTTPBadRequestUserNotFound
} else if err != nil {
return err
}
permission, err := user.ParsePermission(req.Permission)
if err != nil {
return errHTTPBadRequestPermissionInvalid

View file

@ -9,6 +9,87 @@ import (
"time"
)
func TestUser_AddRemove(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
// Create user via API
rr := request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Create user with tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 4, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Nil(t, users[1].Tier)
require.Equal(t, "emma", users[2].Name)
require.Equal(t, user.RoleUser, users[2].Role)
require.Equal(t, "tier1", users[2].Tier.Code)
require.Equal(t, user.Everyone, users[3].Name)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
}
func TestUser_AddRemove_Failures(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
// Cannot create user with invalid username
rr := request(t, s, "PUT", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
// Cannot create user if user already exists
rr = request(t, s, "PUT", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
// Cannot create user with invalid tier
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
// Cannot delete user as non-admin
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
}
func TestAccess_AllowReset(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
@ -26,7 +107,7 @@ func TestAccess_AllowReset(t *testing.T) {
require.Equal(t, 403, rr.Code)
// Grant access
rr = request(t, s, "POST", "/v1/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
@ -38,7 +119,7 @@ func TestAccess_AllowReset(t *testing.T) {
require.Equal(t, 200, rr.Code)
// Reset access
rr = request(t, s, "DELETE", "/v1/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
@ -60,7 +141,7 @@ func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
// Grant access fails, because non-admin
rr := request(t, s, "POST", "/v1/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
@ -88,7 +169,7 @@ func TestAccess_AllowReset_KillConnection(t *testing.T) {
time.Sleep(500 * time.Millisecond)
// Reset access
rr := request(t, s, "DELETE", "/v1/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)

View file

@ -251,13 +251,25 @@ type apiUserAddRequest struct {
// Do not add 'role' here. We don't want to add admins via the API.
}
type apiUserResponse struct {
Username string `json:"username"`
Role string `json:"role"`
Tier string `json:"tier,omitempty"`
Grants []*apiUserGrantResponse `json:"grants,omitempty"`
}
type apiUserGrantResponse struct {
Topic string `json:"topic"` // This may be a pattern
Permission string `json:"permission"`
}
type apiUserDeleteRequest struct {
Username string `json:"username"`
}
type apiAccessAllowRequest struct {
Username string `json:"username"`
Topic string `json:"topic"`
Topic string `json:"topic"` // This may be a pattern
Permission string `json:"permission"`
}

View file

@ -185,6 +185,11 @@ const (
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
`
selectUserAllAccessQuery = `
SELECT user_id, topic, read, write
FROM user_access
ORDER BY write DESC, read DESC, topic
`
selectUserAccessQuery = `
SELECT topic, read, write
FROM user_access
@ -966,6 +971,33 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
return user, nil
}
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
func (a *Manager) AllGrants() (map[string][]Grant, error) {
rows, err := a.db.Query(selectUserAllAccessQuery)
if err != nil {
return nil, err
}
defer rows.Close()
grants := make(map[string][]Grant, 0)
for rows.Next() {
var userID, topic string
var read, write bool
if err := rows.Scan(&userID, &topic, &read, &write); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
if _, ok := grants[userID]; !ok {
grants[userID] = make([]Grant, 0)
}
grants[userID] = append(grants[userID], Grant{
TopicPattern: fromSQLWildcard(topic),
Allow: NewPermission(read, write),
})
}
return grants, nil
}
// Grants returns all user-specific access control entries
func (a *Manager) Grants(username string) ([]Grant, error) {
rows, err := a.db.Query(selectUserAccessQuery, username)