1
0
Fork 0
mirror of https://github.com/binwiederhier/ntfy.git synced 2024-11-21 19:03:26 +01:00

Self-review, round 2

This commit is contained in:
binwiederhier 2023-02-09 15:24:12 -05:00
parent bcb22d8d4c
commit e6bb5f484c
24 changed files with 288 additions and 183 deletions

View file

@ -61,7 +61,7 @@ var cmdTier = &cli.Command{
Tiers can be used to grant users higher limits, such as daily message limits, attachment size, or
make it possible for users to reserve topics.
This is a server-only command. It directly reads from the user.db as defined in the server config
This is a server-only command. It directly reads from user.db as defined in the server config
file server.yml. The command only works if 'auth-file' is properly defined.
Examples:
@ -102,7 +102,7 @@ Examples:
After updating a tier, you may have to restart the ntfy server to apply them
to all visitors.
This is a server-only command. It directly reads from the user.db as defined in the server config
This is a server-only command. It directly reads from user.db as defined in the server config
file server.yml. The command only works if 'auth-file' is properly defined.
Examples:
@ -124,7 +124,7 @@ Examples:
You cannot remove a tier if there are users associated with a tier. Use "ntfy user change-tier"
to remove or switch their tier first.
This is a server-only command. It directly reads from the user.db as defined in the server config
This is a server-only command. It directly reads from user.db as defined in the server config
file server.yml. The command only works if 'auth-file' is properly defined.
Example:
@ -138,7 +138,7 @@ Example:
Action: execTierList,
Description: `Shows a list of all configured tiers.
This is a server-only command. It directly reads from the user.db as defined in the server config
This is a server-only command. It directly reads from user.db as defined in the server config
file server.yml. The command only works if 'auth-file' is properly defined.
`,
},

View file

@ -27,8 +27,26 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) {
require.Contains(t, stderr.String(), "- Message limit: 1234")
app, _, _, stderr = newTestApp()
require.Nil(t, runTierCommand(app, conf, "change", "--message-limit", "999", "pro"))
require.Nil(t, runTierCommand(app, conf, "change",
"--message-limit=999",
"--message-expiry-duration=99h",
"--email-limit=91",
"--reservation-limit=98",
"--attachment-file-size-limit=100m",
"--attachment-expiry-duration=7h",
"--attachment-total-size-limit=10G",
"--attachment-bandwidth-limit=100G",
"--stripe-price-id=price_991",
"pro",
))
require.Contains(t, stderr.String(), "- Message limit: 999")
require.Contains(t, stderr.String(), "- Message expiry duration: 99h")
require.Contains(t, stderr.String(), "- Email limit: 91")
require.Contains(t, stderr.String(), "- Reservation limit: 98")
require.Contains(t, stderr.String(), "- Attachment file size limit: 100.0 MB")
require.Contains(t, stderr.String(), "- Attachment expiry duration: 7h")
require.Contains(t, stderr.String(), "- Attachment total size limit: 10.0 GB")
require.Contains(t, stderr.String(), "- Stripe price: price_991")
app, _, _, stderr = newTestApp()
require.Nil(t, runTierCommand(app, conf, "remove", "pro"))

View file

@ -42,6 +42,9 @@ User access tokens can be used to publish, subscribe, or perform any other user-
Tokens have full access, and can perform any task a user can do. They are meant to be used to
avoid spreading the password to various places.
This is a server-only command. It directly reads from user.db as defined in the server config
file server.yml. The command only works if 'auth-file' is properly defined.
Examples:
ntfy token add phil # Create token for user phil which never expires
ntfy token add --expires=2d phil # Create token for user phil which expires in 2 days
@ -66,7 +69,7 @@ Example:
Action: execTokenList,
Description: `Shows a list of all tokens.
This is a server-only command. It directly reads from the user.db as defined in the server config
This is a server-only command. It directly reads from user.db as defined in the server config
file server.yml. The command only works if 'auth-file' is properly defined.`,
},
},

View file

@ -141,7 +141,7 @@ Example:
This command is an alias to calling 'ntfy access' (display access control list).
This is a server-only command. It directly reads from the user.db as defined in the server config
This is a server-only command. It directly reads from user.db as defined in the server config
file server.yml. The command only works if 'auth-file' is properly defined.
`,
},

View file

@ -13,6 +13,7 @@ import (
const (
tagField = "tag"
errorField = "error"
timeTakenField = "time_taken_ms"
exitCodeField = "exit_code"
timestampFormat = "2006-01-02T15:04:05.999Z07:00"
)
@ -80,6 +81,13 @@ func (e *Event) Time(t time.Time) *Event {
return e
}
// Timing runs f and records the time if took to execute it in "time_taken_ms"
func (e *Event) Timing(f func()) *Event {
start := time.Now()
f()
return e.Field(timeTakenField, time.Since(start).Milliseconds())
}
// Err adds an "error" field to the log event
func (e *Event) Err(err error) *Event {
if err == nil {

View file

@ -78,6 +78,11 @@ func Time(time time.Time) *Event {
return newEvent().Time(time)
}
// Timing runs f and records the time if took to execute it in "time_taken_ms"
func Timing(f func()) *Event {
return newEvent().Timing(f)
}
// CurrentLevel returns the current log level
func CurrentLevel() Level {
mu.Lock()

View file

@ -2,6 +2,7 @@ package log
import (
"bytes"
"encoding/json"
"github.com/stretchr/testify/require"
"os"
"testing"
@ -131,6 +132,25 @@ func TestLog_NoAllocIfNotPrinted(t *testing.T) {
require.Equal(t, expected, out.String())
}
func TestLog_Timing(t *testing.T) {
t.Cleanup(resetState)
var out bytes.Buffer
SetOutput(&out)
SetFormat(JSONFormat)
Timing(func() { time.Sleep(300 * time.Millisecond) }).
Time(time.Unix(12, 0).UTC()).
Info("A thing that takes a while")
var ev struct {
TimeTakenMs int64 `json:"time_taken_ms"`
}
require.Nil(t, json.Unmarshal(out.Bytes(), &ev))
require.True(t, ev.TimeTakenMs >= 300)
require.Contains(t, out.String(), `{"time":"1970-01-01T00:00:12Z","level":"INFO","message":"A thing that takes a while","time_taken_ms":`)
}
type fakeError struct {
Code int
Message string

View file

@ -164,6 +164,7 @@ func NewConfig() *Config {
AttachmentExpiryDuration: DefaultAttachmentExpiryDuration,
KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval,
DisallowedTopics: DefaultDisallowedTopics,
WebRootIsApp: false,
DelayedSenderInterval: DefaultDelayedSenderInterval,
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,

View file

@ -51,6 +51,8 @@ const (
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
COMMIT;
`
@ -215,6 +217,8 @@ const (
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
`
migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
@ -883,8 +887,5 @@ func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
if _, err := tx.Exec(updateSchemaVersion, 10); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
return nil // Update this when a new version is added
return tx.Commit()
}

View file

@ -37,12 +37,13 @@ import (
- HIGH Docs
- tiers
- api
- tokens
- HIGH Self-review
- MEDIUM: Test for expiring messages after reservation removal
- MEDIUM: uploading attachments leads to 404 -- race
- MEDIUM: Do not call tiers endoint when payments is not enabled
- MEDIUM: Test new token endpoints & never-expiring token
- LOW: UI: Flickering upgrade banner when logging in
- LOW: Menu item -> popup click should not open page
*/
@ -140,6 +141,7 @@ const (
const (
tagStartup = "startup"
tagPublish = "publish"
tagSubscribe = "subscribe"
tagFirebase = "firebase"
tagEmail = "email" // Send email
tagSMTP = "smtp" // Receive email
@ -649,7 +651,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
}
u := v.User()
if s.userManager != nil && u != nil && u.Tier != nil {
go s.userManager.EnqueueStats(u.ID, v.Stats())
go s.userManager.EnqueueUserStats(u.ID, v.Stats())
}
s.mu.Lock()
s.messages++
@ -956,8 +958,8 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v
}
func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error {
logvr(v, r).Debug("HTTP stream connection opened")
defer logvr(v, r).Debug("HTTP stream connection closed")
logvr(v, r).Tag(tagSubscribe).Debug("HTTP stream connection opened")
defer logvr(v, r).Tag(tagSubscribe).Debug("HTTP stream connection closed")
if !v.SubscriptionAllowed() {
return errHTTPTooManyRequestsLimitSubscriptions
}
@ -1025,7 +1027,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
case <-r.Context().Done():
return nil
case <-time.After(s.config.KeepaliveInterval):
logvr(v, r).Trace("Sending keepalive message")
logvr(v, r).Tag(tagSubscribe).Trace("Sending keepalive message")
v.Keepalive()
if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
return err
@ -1283,70 +1285,86 @@ func (s *Server) topicFromID(id string) (*topic, error) {
}
func (s *Server) execManager() {
log.Tag(tagManager).Debug("Starting manager")
defer log.Tag(tagManager).Debug("Finished manager")
// WARNING: Make sure to only selectively lock with the mutex, and be aware that this
// there is no mutex for the entire function.
// Expire visitors from rate visitors map
s.mu.Lock()
staleVisitors := 0
for ip, v := range s.visitors {
if v.Stale() {
log.Tag(tagManager).With(v).Trace("Deleting stale visitor")
delete(s.visitors, ip)
staleVisitors++
}
}
s.mu.Unlock()
log.Tag(tagManager).Field("stale_visitors", staleVisitors).Debug("Deleted %d stale visitor(s)", staleVisitors)
log.
Tag(tagManager).
Timing(func() {
s.mu.Lock()
defer s.mu.Unlock()
for ip, v := range s.visitors {
if v.Stale() {
log.Tag(tagManager).With(v).Trace("Deleting stale visitor")
delete(s.visitors, ip)
staleVisitors++
}
}
}).
Field("stale_visitors", staleVisitors).
Debug("Deleted %d stale visitor(s)", staleVisitors)
// Delete expired user tokens and users
if s.userManager != nil {
if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Tag(tagManager).Err(err).Warn("Error expiring user tokens")
}
if err := s.userManager.RemoveDeletedUsers(); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting soft-deleted users")
}
log.
Tag(tagManager).
Timing(func() {
if err := s.userManager.RemoveExpiredTokens(); err != nil {
log.Tag(tagManager).Err(err).Warn("Error expiring user tokens")
}
if err := s.userManager.RemoveDeletedUsers(); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting soft-deleted users")
}
}).
Debug("Removed expired tokens and users")
}
// Delete expired attachments
if s.fileCache != nil {
ids, err := s.messageCache.AttachmentsExpired()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Error retrieving expired attachments")
} else if len(ids) > 0 {
if log.Tag(tagManager).IsDebug() {
log.Tag(tagManager).Debug("Deleting attachments %s", strings.Join(ids, ", "))
}
if err := s.fileCache.Remove(ids...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting attachments")
}
if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
}
} else {
log.Tag(tagManager).Debug("No expired attachments to delete")
}
log.
Tag(tagManager).
Timing(func() {
ids, err := s.messageCache.AttachmentsExpired()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Error retrieving expired attachments")
} else if len(ids) > 0 {
if log.Tag(tagManager).IsDebug() {
log.Tag(tagManager).Debug("Deleting attachments %s", strings.Join(ids, ", "))
}
if err := s.fileCache.Remove(ids...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting attachments")
}
if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
}
} else {
log.Tag(tagManager).Debug("No expired attachments to delete")
}
}).
Debug("Deleted expired attachments")
}
// Prune messages
log.Tag(tagManager).Debug("Manager: Pruning messages")
expiredMessageIDs, err := s.messageCache.MessagesExpired()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Error retrieving expired messages")
} else if len(expiredMessageIDs) > 0 {
if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting attachments for expired messages")
}
if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
}
} else {
log.Tag(tagManager).Debug("No expired messages to delete")
}
log.
Tag(tagManager).
Timing(func() {
expiredMessageIDs, err := s.messageCache.MessagesExpired()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Error retrieving expired messages")
} else if len(expiredMessageIDs) > 0 {
if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error deleting attachments for expired messages")
}
if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
log.Tag(tagManager).Err(err).Warn("Error marking attachments deleted")
}
} else {
log.Tag(tagManager).Debug("No expired messages to delete")
}
}).
Debug("Pruned messages")
// Message count per topic
var messagesCached int
@ -1360,20 +1378,26 @@ func (s *Server) execManager() {
}
// Remove subscriptions without subscribers
s.mu.Lock()
var subscribers int
for _, t := range s.topics {
subs := t.SubscribersCount()
log.Tag(tagManager).Trace("- topic %s: %d subscribers", t.ID, subs)
msgs, exists := messageCounts[t.ID]
if subs == 0 && (!exists || msgs == 0) {
log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
delete(s.topics, t.ID)
continue
}
subscribers += subs
}
s.mu.Unlock()
var emptyTopics, subscribers int
log.
Tag(tagManager).
Timing(func() {
s.mu.Lock()
defer s.mu.Unlock()
for _, t := range s.topics {
subs := t.SubscribersCount()
log.Tag(tagManager).Trace("- topic %s: %d subscribers", t.ID, subs)
msgs, exists := messageCounts[t.ID]
if subs == 0 && (!exists || msgs == 0) {
log.Tag(tagManager).Trace("Deleting empty topic %s", t.ID)
emptyTopics++
delete(s.topics, t.ID)
continue
}
subscribers += subs
}
}).
Debug("Removed %d empty topic(s)", emptyTopics)
// Mail stats
var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
@ -1407,6 +1431,10 @@ func (s *Server) execManager() {
Info("Server stats")
}
func (s *Server) expireVisitors() {
}
func (s *Server) runSMTPServer() error {
s.smtpServerBackend = newMailBackend(s.config, s.handle)
s.smtpServer = smtp.NewServer(s.smtpServerBackend)
@ -1424,7 +1452,10 @@ func (s *Server) runManager() {
for {
select {
case <-time.After(s.config.ManagerInterval):
s.execManager()
log.
Tag(tagManager).
Timing(s.execManager).
Debug("Manager finished")
case <-s.closeChan:
return
}

View file

@ -314,7 +314,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
}
}
logvr(v, r).Tag(tagAccount).Debug("Changing account settings for user %s", u.Name)
if err := s.userManager.ChangeSettings(u); err != nil {
if err := s.userManager.ChangeSettings(u.ID, prefs); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
@ -338,7 +338,8 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
}
if newSubscription.ID == "" {
newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
u.Prefs.Subscriptions = append(u.Prefs.Subscriptions, newSubscription)
prefs := u.Prefs
prefs.Subscriptions = append(prefs.Subscriptions, newSubscription)
logvr(v, r).
Tag(tagAccount).
Fields(log.Context{
@ -346,7 +347,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
"topic": newSubscription.Topic,
}).
Debug("Adding subscription for user %s", u.Name)
if err := s.userManager.ChangeSettings(u); err != nil {
if err := s.userManager.ChangeSettings(u.ID, prefs); err != nil {
return err
}
}
@ -367,8 +368,9 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.
if u.Prefs == nil || u.Prefs.Subscriptions == nil {
return errHTTPNotFound
}
prefs := u.Prefs
var subscription *user.Subscription
for _, sub := range u.Prefs.Subscriptions {
for _, sub := range prefs.Subscriptions {
if sub.ID == subscriptionID {
sub.DisplayName = updatedSubscription.DisplayName
subscription = sub
@ -386,7 +388,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.
"display_name": subscription.DisplayName,
}).
Debug("Changing subscription for user %s", u.Name)
if err := s.userManager.ChangeSettings(u); err != nil {
if err := s.userManager.ChangeSettings(u.ID, prefs); err != nil {
return err
}
return s.writeJSON(w, subscription)
@ -417,8 +419,9 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
}
}
if len(newSubscriptions) < len(u.Prefs.Subscriptions) {
u.Prefs.Subscriptions = newSubscriptions
if err := s.userManager.ChangeSettings(u); err != nil {
prefs := u.Prefs
prefs.Subscriptions = newSubscriptions
if err := s.userManager.ChangeSettings(u.ID, prefs); err != nil {
return err
}
}

View file

@ -724,5 +724,5 @@ func TestAccount_Persist_UserStats_After_Tier_Change(t *testing.T) {
time.Sleep(300 * time.Millisecond)
u, err = s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, int64(0), u.Stats.Messages) // v.EnqueueStats had run!
require.Equal(t, int64(0), u.Stats.Messages) // v.EnqueueUserStats had run!
}

View file

@ -938,7 +938,7 @@ func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) {
u, err := s.userManager.User("phil")
require.Nil(t, err)
s.userManager.EnqueueStats(u.ID, &user.Stats{
s.userManager.EnqueueUserStats(u.ID, &user.Stats{
Messages: 123456,
Emails: 999,
})

View file

@ -88,7 +88,7 @@ func (t *topic) CancelSubscribers(exceptUserID string) {
defer t.mu.Unlock()
for _, s := range t.subscribers {
if s.userID != exceptUserID {
log.Field("topic", t.ID).Trace("Canceling subscriber %s", s.userID)
log.Tag(tagSubscribe).Field("topic", t.ID).Debug("Canceling subscriber %s", s.userID)
s.cancel()
}
}

View file

@ -27,7 +27,7 @@ const (
)
// Constants used to convert a tier-user's MessageLimit (see user.Tier) into adequate request limiter
// values (token bucket).
// values (token bucket). This is only used to increase the values in server.yml, never decrease them.
//
// Example: Assuming a user.Tier's MessageLimit is 10,000:
// - the allowed burst is 500 (= 10,000 * 5%), which is < 1000 (the max)
@ -59,7 +59,7 @@ type visitor struct {
subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
authLimiter *rate.Limiter // Limiter for incorrect login attempts
authLimiter *rate.Limiter // Limiter for incorrect login attempts, may be nil
firebase time.Time // Next allowed Firebase message
seen time.Time // Last seen time of this visitor (needed for removal of stale visitors)
mu sync.Mutex
@ -360,7 +360,7 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool
v.authLimiter = nil // Users are already logged in, no need to limit requests
}
if enqueueUpdate && v.user != nil {
go v.userManager.EnqueueStats(v.user.ID, &user.Stats{
go v.userManager.EnqueueUserStats(v.user.ID, &user.Stats{
Messages: messages,
Emails: emails,
})

View file

@ -1,3 +1,4 @@
// Package user deals with authentication and authorization against topics
package user
import (
@ -28,7 +29,7 @@ const (
tokenPrefix = "tk_"
tokenLength = 32
tokenMaxCount = 20 // Only keep this many tokens in the table per user
tagManager = "user_manager"
tag = "user_manager"
)
// Default constants that may be overridden by configs
@ -47,7 +48,7 @@ var (
const (
createTablesQueriesNoTx = `
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
@ -89,7 +90,7 @@ const (
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
@ -109,7 +110,7 @@ const (
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
@ -121,7 +122,7 @@ const (
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
WHERE u.id = ?
`
selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
@ -151,12 +152,12 @@ const (
`
insertUserQuery = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES (?, ?, ?, ?, ?, ?)
`
selectUsernamesQuery = `
SELECT user
FROM user
SELECT user
FROM user
ORDER BY
CASE role
WHEN 'admin' THEN 1
@ -166,7 +167,7 @@ const (
`
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
@ -174,15 +175,15 @@ const (
deleteUserQuery = `DELETE FROM user WHERE user = ?`
upsertUserAccessQuery = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
ON CONFLICT (user_id, topic)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
`
selectUserAccessQuery = `
SELECT topic, read, write
FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY write DESC, read DESC, topic
`
selectUserReservationsQuery = `
@ -201,9 +202,9 @@ const (
selectUserHasReservationQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
AND topic = ?
AND topic = ?
`
selectOtherAccessCountQuery = `
SELECT COUNT(*)
@ -213,13 +214,13 @@ const (
`
deleteAllAccessQuery = `DELETE FROM user_access`
deleteUserAccessQuery = `
DELETE FROM user_access
DELETE FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
`
deleteTopicAccessQuery = `
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
AND topic = ?
`
@ -239,7 +240,7 @@ const (
SELECT user_id, token
FROM user_token
WHERE user_id = ?
ORDER BY expires DESC
ORDER BY expires DESC
LIMIT ?
)
`
@ -249,7 +250,7 @@ const (
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
updateTierQuery = `
UPDATE tier
UPDATE tier
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_price_id = ?
WHERE code = ?
`
@ -272,7 +273,7 @@ const (
deleteTierQuery = `DELETE FROM tier WHERE code = ?`
updateBillingQuery = `
UPDATE user
UPDATE user
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
WHERE user = ?
`
@ -291,7 +292,7 @@ const (
`
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
migrate1To2InsertFromOldTablesAndDropNoTx = `
@ -305,6 +306,12 @@ const (
`
)
var (
migrations = map[int]func(db *sql.DB) error{
1: migrateFrom1,
}
)
// Manager is an implementation of Manager. It stores users and access control list
// in a SQLite database.
type Manager struct {
@ -350,15 +357,15 @@ func (a *Manager) Authenticate(username, password string) (*User, error) {
}
user, err := a.User(username)
if err != nil {
log.Tag(tagManager).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated
} else if user.Deleted {
log.Tag(tagManager).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
log.Tag(tag).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
return nil, ErrUnauthenticated
} else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
log.Tag(tagManager).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
return nil, ErrUnauthenticated
}
return user, nil
@ -372,7 +379,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
}
user, err := a.userByToken(token)
if err != nil {
log.Tag(tagManager).Field("token", token).Err(err).Trace("Authentication of token failed")
log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed")
return nil, ErrUnauthenticated
}
user.Token = token
@ -532,12 +539,12 @@ func (a *Manager) RemoveDeletedUsers() error {
}
// ChangeSettings persists the user settings
func (a *Manager) ChangeSettings(user *User) error {
prefs, err := json.Marshal(user.Prefs)
func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error {
b, err := json.Marshal(prefs)
if err != nil {
return err
}
if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
if _, err := a.db.Exec(updateUserPrefsQuery, string(b), userID); err != nil {
return err
}
return nil
@ -554,9 +561,9 @@ func (a *Manager) ResetStats() error {
return nil
}
// EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
// EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
// batches at a regular interval
func (a *Manager) EnqueueStats(userID string, stats *Stats) {
func (a *Manager) EnqueueUserStats(userID string, stats *Stats) {
a.mu.Lock()
defer a.mu.Unlock()
a.statsQueue[userID] = stats
@ -574,10 +581,10 @@ func (a *Manager) asyncQueueWriter(interval time.Duration) {
ticker := time.NewTicker(interval)
for range ticker.C {
if err := a.writeUserStatsQueue(); err != nil {
log.Tag(tagManager).Err(err).Warn("Writing user stats queue failed")
log.Tag(tag).Err(err).Warn("Writing user stats queue failed")
}
if err := a.writeTokenUpdateQueue(); err != nil {
log.Tag(tagManager).Err(err).Warn("Writing token update queue failed")
log.Tag(tag).Err(err).Warn("Writing token update queue failed")
}
}
}
@ -586,7 +593,7 @@ func (a *Manager) writeUserStatsQueue() error {
a.mu.Lock()
if len(a.statsQueue) == 0 {
a.mu.Unlock()
log.Tag(tagManager).Trace("No user stats updates to commit")
log.Tag(tag).Trace("No user stats updates to commit")
return nil
}
statsQueue := a.statsQueue
@ -597,10 +604,10 @@ func (a *Manager) writeUserStatsQueue() error {
return err
}
defer tx.Rollback()
log.Tag(tagManager).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
for userID, update := range statsQueue {
log.
Tag(tagManager).
Tag(tag).
Fields(log.Context{
"user_id": userID,
"messages_count": update.Messages,
@ -618,7 +625,7 @@ func (a *Manager) writeTokenUpdateQueue() error {
a.mu.Lock()
if len(a.tokenQueue) == 0 {
a.mu.Unlock()
log.Tag(tagManager).Trace("No token updates to commit")
log.Tag(tag).Trace("No token updates to commit")
return nil
}
tokenQueue := a.tokenQueue
@ -629,9 +636,9 @@ func (a *Manager) writeTokenUpdateQueue() error {
return err
}
defer tx.Rollback()
log.Tag(tagManager).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
for tokenID, update := range tokenQueue {
log.Tag(tagManager).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
return err
}
@ -718,7 +725,7 @@ func (a *Manager) MarkUserRemoved(user *User) error {
return err
}
defer tx.Rollback()
if _, err := a.db.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
if _, err := tx.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
return err
}
if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
@ -1012,7 +1019,6 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
// CheckAllowAccess tests if a user may create an access control entry for the given topic.
// If there are any ACL entries that are not owned by the user, an error is returned.
// FIXME is this the same as HasReservation?
func (a *Manager) CheckAllowAccess(username string, topic string) error {
if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
return ErrInvalidArgument
@ -1275,10 +1281,18 @@ func setupDB(db *sql.DB) error {
// Do migrations
if schemaVersion == currentSchemaVersion {
return nil
} else if schemaVersion == 1 {
return migrateFrom1(db)
} else if schemaVersion > currentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
}
return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
for i := schemaVersion; i < currentSchemaVersion; i++ {
fn, ok := migrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db); err != nil {
return err
}
}
return nil
}
func setupNewDB(db *sql.DB) error {
@ -1292,7 +1306,7 @@ func setupNewDB(db *sql.DB) error {
}
func migrateFrom1(db *sql.DB) error {
log.Tag(tagManager).Info("Migrating user database schema: from 1 to 2")
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
tx, err := db.Begin()
if err != nil {
return err
@ -1339,7 +1353,7 @@ func migrateFrom1(db *sql.DB) error {
if err := tx.Commit(); err != nil {
return err
}
return nil // Update this when a new version is added
return nil
}
func nullString(s string) sql.NullString {

View file

@ -562,7 +562,7 @@ func TestManager_EnqueueStats(t *testing.T) {
require.Nil(t, err)
require.Equal(t, int64(0), u.Stats.Messages)
require.Equal(t, int64(0), u.Stats.Emails)
a.EnqueueStats(u.ID, &Stats{
a.EnqueueUserStats(u.ID, &Stats{
Messages: 11,
Emails: 2,
})
@ -595,7 +595,7 @@ func TestManager_ChangeSettings(t *testing.T) {
require.Nil(t, u.Prefs.Language)
// Save with new settings
u.Prefs = &Prefs{
prefs := &Prefs{
Language: util.String("de"),
Notification: &NotificationPrefs{
Sound: util.String("ding"),
@ -610,7 +610,7 @@ func TestManager_ChangeSettings(t *testing.T) {
},
},
}
require.Nil(t, a.ChangeSettings(u))
require.Nil(t, a.ChangeSettings(u.ID, prefs))
// Read again
u, err = a.User("ben")

View file

@ -1,4 +1,3 @@
// Package user deals with authentication and authorization against topics
package user
import (

View file

@ -234,7 +234,7 @@ func FormatSize(b int64) string {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp])
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
}
// ReadPassword will read a password from STDIN. If the terminal supports it, it will not print the

View file

@ -176,24 +176,25 @@
"account_basics_password_dialog_current_password_label": "Current password",
"account_basics_password_dialog_new_password_label": "New password",
"account_basics_password_dialog_confirm_password_label": "Confirm password",
"account_basics_password_dialog_button_cancel": "Cancel",
"account_basics_password_dialog_button_submit": "Change password",
"account_basics_password_dialog_current_password_incorrect": "Password incorrect",
"account_usage_title": "Usage",
"account_usage_of_limit": "of {{limit}}",
"account_usage_unlimited": "Unlimited",
"account_usage_limits_reset_daily": "Usage limits are reset daily at midnight (UTC)",
"account_usage_tier_title": "Account type",
"account_usage_tier_description": "Your account's power level",
"account_usage_tier_admin": "Admin",
"account_usage_tier_basic": "Basic",
"account_usage_tier_free": "Free",
"account_usage_tier_upgrade_button": "Upgrade to Pro",
"account_usage_tier_change_button": "Change",
"account_usage_tier_paid_until": "Subscription paid until {{date}}, and will auto-renew",
"account_usage_tier_payment_overdue": "Your payment is overdue. Please update your payment method, or your account will be downgraded soon.",
"account_usage_tier_canceled_subscription": "Your subscription was canceled and will be downgraded to a free account on {{date}}.",
"account_usage_manage_billing_button": "Manage billing",
"account_basics_tier_title": "Account type",
"account_basics_tier_description": "Your account's power level",
"account_basics_tier_admin": "Admin",
"account_basics_tier_admin_suffix_with_tier": "(with {{tier}} tier)",
"account_basics_tier_admin_suffix_no_tier": "(no tier)",
"account_basics_tier_basic": "Basic",
"account_basics_tier_free": "Free",
"account_basics_tier_upgrade_button": "Upgrade to Pro",
"account_basics_tier_change_button": "Change",
"account_basics_tier_paid_until": "Subscription paid until {{date}}, and will auto-renew",
"account_basics_tier_payment_overdue": "Your payment is overdue. Please update your payment method, or your account will be downgraded soon.",
"account_basics_tier_canceled_subscription": "Your subscription was canceled and will be downgraded to a free account on {{date}}.",
"account_basics_tier_manage_billing_button": "Manage billing",
"account_usage_messages_title": "Published messages",
"account_usage_emails_title": "Emails sent",
"account_usage_reservations_title": "Reserved topics",
@ -204,7 +205,7 @@
"account_usage_cannot_create_portal_session": "Unable to open billing portal",
"account_delete_title": "Delete account",
"account_delete_description": "Permanently delete your account",
"account_delete_dialog_description": "This will permanently delete your account, including all data that is stored on the server. If you really want to proceed, please confirm with your password in the box below.",
"account_delete_dialog_description": "This will permanently delete your account, including all data that is stored on the server. After deletion, your username will be unavailable for 7 days. If you really want to proceed, please confirm with your password in the box below.",
"account_delete_dialog_label": "Password",
"account_delete_dialog_button_cancel": "Cancel",
"account_delete_dialog_button_submit": "Permanently delete account",

View file

@ -27,6 +27,7 @@ class AccountApi {
constructor() {
this.timer = null;
this.listener = null; // Fired when account is fetched from remote
this.tiers = null; // Cached
}
registerListener(listener) {
@ -148,11 +149,7 @@ class AccountApi {
console.log(`[AccountApi] Extending user access token ${url}`);
await fetchOrThrow(url, {
method: "PATCH",
headers: withBearerAuth({}, session.token()),
body: JSON.stringify({
token: session.token(),
expires: Math.floor(Date.now() / 1000) + 6220800 // FIXME
})
headers: withBearerAuth({}, session.token())
});
}
@ -239,10 +236,14 @@ class AccountApi {
}
async billingTiers() {
if (this.tiers) {
return this.tiers;
}
const url = tiersUrl(config.base_url);
console.log(`[AccountApi] Fetching billing tiers`);
const response = await fetchOrThrow(url); // No auth needed!
return await response.json(); // May throw SyntaxError
this.tiers = await response.json(); // May throw SyntaxError
return this.tiers;
}
async createBillingSubscription(tier) {

View file

@ -198,7 +198,7 @@ const ChangePasswordDialog = (props) => {
/>
</DialogContent>
<DialogFooter status={error}>
<Button onClick={props.onClose}>{t("account_basics_password_dialog_button_cancel")}</Button>
<Button onClick={props.onClose}>{t("common_cancel")}</Button>
<Button
onClick={handleDialogSubmit}
disabled={newPassword.length === 0 || currentPassword.length === 0 || newPassword !== confirmPassword}
@ -242,10 +242,10 @@ const AccountType = () => {
let accountType;
if (account.role === Role.ADMIN) {
const tierSuffix = (account.tier) ? `(with ${account.tier.name} tier)` : `(no tier)`;
accountType = `${t("account_usage_tier_admin")} ${tierSuffix}`;
const tierSuffix = (account.tier) ? t("account_basics_tier_admin_suffix_with_tier", { tier: account.tier.name }) : t("account_basics_tier_admin_suffix_no_tier");
accountType = `${t("account_basics_tier_admin")} ${tierSuffix}`;
} else if (!account.tier) {
accountType = (config.enable_payments) ? t("account_usage_tier_free") : t("account_usage_tier_basic");
accountType = (config.enable_payments) ? t("account_basics_tier_free") : t("account_basics_tier_basic");
} else {
accountType = account.tier.name;
}
@ -253,13 +253,13 @@ const AccountType = () => {
return (
<Pref
alignTop={account.billing?.status === SubscriptionStatus.PAST_DUE || account.billing?.cancel_at > 0}
title={t("account_usage_tier_title")}
description={t("account_usage_tier_description")}
title={t("account_basics_tier_title")}
description={t("account_basics_tier_description")}
>
<div>
{accountType}
{account.billing?.paid_until && !account.billing?.cancel_at &&
<Tooltip title={t("account_usage_tier_paid_until", { date: formatShortDate(account.billing?.paid_until) })}>
<Tooltip title={t("account_basics_tier_paid_until", { date: formatShortDate(account.billing?.paid_until) })}>
<span><InfoIcon/></span>
</Tooltip>
}
@ -270,7 +270,7 @@ const AccountType = () => {
startIcon={<CelebrationIcon sx={{ color: "#55b86e" }}/>}
onClick={handleUpgradeClick}
sx={{ml: 1}}
>{t("account_usage_tier_upgrade_button")}</Button>
>{t("account_basics_tier_upgrade_button")}</Button>
}
{config.enable_payments && account.role === Role.USER && account.billing?.subscription &&
<Button
@ -278,7 +278,7 @@ const AccountType = () => {
size="small"
onClick={handleUpgradeClick}
sx={{ml: 1}}
>{t("account_usage_tier_change_button")}</Button>
>{t("account_basics_tier_change_button")}</Button>
}
{config.enable_payments && account.role === Role.USER && account.billing?.customer &&
<Button
@ -286,19 +286,21 @@ const AccountType = () => {
size="small"
onClick={handleManageBilling}
sx={{ml: 1}}
>{t("account_usage_manage_billing_button")}</Button>
>{t("account_basics_tier_manage_billing_button")}</Button>
}
{config.enable_payments &&
<UpgradeDialog
key={`upgradeDialogFromAccount${upgradeDialogKey}`}
open={upgradeDialogOpen}
onCancel={() => setUpgradeDialogOpen(false)}
/>
}
<UpgradeDialog
key={`upgradeDialogFromAccount${upgradeDialogKey}`}
open={upgradeDialogOpen}
onCancel={() => setUpgradeDialogOpen(false)}
/>
</div>
{account.billing?.status === SubscriptionStatus.PAST_DUE &&
<Alert severity="error" sx={{mt: 1}}>{t("account_usage_tier_payment_overdue")}</Alert>
<Alert severity="error" sx={{mt: 1}}>{t("account_basics_tier_payment_overdue")}</Alert>
}
{account.billing?.cancel_at > 0 &&
<Alert severity="warning" sx={{mt: 1}}>{t("account_usage_tier_canceled_subscription", { date: formatShortDate(account.billing.cancel_at) })}</Alert>
<Alert severity="warning" sx={{mt: 1}}>{t("account_basics_tier_canceled_subscription", { date: formatShortDate(account.billing.cancel_at) })}</Alert>
}
<Portal>
<Snackbar

View file

@ -212,7 +212,7 @@ const TierCard = (props) => {
}}>{labelText}</div>
}
<Typography variant="h5" component="div">
{tier.name || t("account_usage_tier_free")}
{tier.name || t("account_basics_tier_free")}
</Typography>
<List dense>
{tier.limits.reservations > 0 && <FeatureItem>{t("account_upgrade_dialog_tier_features_reservations", { reservations: tier.limits.reservations })}</FeatureItem>}

View file

@ -1,8 +1,6 @@
import config from "../app/config";
import {shortUrl} from "../app/utils";
// Remember to also update the "disallowedTopics" list!
const routes = {
login: "/login",
signup: "/signup",