mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-12-24 02:32:48 +01:00
140 lines
2.8 KiB
Go
140 lines
2.8 KiB
Go
package server
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"heckel.io/ntfy/util"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
var (
|
|
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, messageIDLength))
|
|
errInvalidFileID = errors.New("invalid file ID")
|
|
errFileExists = errors.New("file exists")
|
|
)
|
|
|
|
// fileCache stores files on disk inside dir, one file per ID, and keeps a
// running total of the bytes they occupy so writes can be capped.
type fileCache struct {
	dir string // directory that holds one file per ID

	// Current and maximum total size of the cached files, in bytes.
	// The methods in this file access totalSizeCurrent while holding mu.
	totalSizeCurrent, totalSizeLimit int64

	mu sync.Mutex
}
|
|
|
|
func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
|
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
|
return nil, err
|
|
}
|
|
size, err := dirSize(dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &fileCache{
|
|
dir: dir,
|
|
totalSizeCurrent: size,
|
|
totalSizeLimit: totalSizeLimit,
|
|
}, nil
|
|
}
|
|
|
|
func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
|
if !fileIDRegex.MatchString(id) {
|
|
return 0, errInvalidFileID
|
|
}
|
|
file := filepath.Join(c.dir, id)
|
|
if _, err := os.Stat(file); err == nil {
|
|
return 0, errFileExists
|
|
}
|
|
f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer f.Close()
|
|
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
|
|
limitWriter := util.NewLimitWriter(f, limiters...)
|
|
size, err := io.Copy(limitWriter, in)
|
|
if err != nil {
|
|
os.Remove(file)
|
|
return 0, err
|
|
}
|
|
if err := f.Close(); err != nil {
|
|
os.Remove(file)
|
|
return 0, err
|
|
}
|
|
c.mu.Lock()
|
|
c.totalSizeCurrent += size
|
|
c.mu.Unlock()
|
|
return size, nil
|
|
}
|
|
|
|
func (c *fileCache) Remove(ids ...string) error {
|
|
for _, id := range ids {
|
|
if !fileIDRegex.MatchString(id) {
|
|
return errInvalidFileID
|
|
}
|
|
file := filepath.Join(c.dir, id)
|
|
_ = os.Remove(file) // Best effort delete
|
|
}
|
|
size, err := dirSize(c.dir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.mu.Lock()
|
|
c.totalSizeCurrent = size
|
|
c.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// Expired returns a list of file IDs for expired files
|
|
func (c *fileCache) Expired(olderThan time.Time) ([]string, error) {
|
|
entries, err := os.ReadDir(c.dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var ids []string
|
|
for _, e := range entries {
|
|
info, err := e.Info()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if info.ModTime().Before(olderThan) && fileIDRegex.MatchString(e.Name()) {
|
|
ids = append(ids, e.Name())
|
|
}
|
|
}
|
|
return ids, nil
|
|
}
|
|
|
|
func (c *fileCache) Size() int64 {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.totalSizeCurrent
|
|
}
|
|
|
|
func (c *fileCache) Remaining() int64 {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
remaining := c.totalSizeLimit - c.totalSizeCurrent
|
|
if remaining < 0 {
|
|
return 0
|
|
}
|
|
return remaining
|
|
}
|
|
|
|
// dirSize returns the sum of the sizes reported for the entries directly
// inside dir. It does not descend into subdirectories.
//
// An entry can disappear between reading the directory and stat-ing it, for
// example when a file is removed concurrently. Such an entry no longer takes
// up space, so it is skipped rather than failing the whole computation. Any
// other error is returned with a size of 0.
func dirSize(dir string) (int64, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return 0, err
	}
	var size int64
	for _, e := range entries {
		info, err := e.Info()
		if err != nil {
			if errors.Is(err, os.ErrNotExist) {
				continue // removed after ReadDir; nothing to count
			}
			return 0, err
		}
		size += info.Size()
	}
	return size, nil
}
|