From 7d8407871c081185b227bd9e14c4e84354cf5dc9 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Tue, 23 Apr 2024 22:44:04 -0400 Subject: [PATCH] WIP Postgres message cache --- go.mod | 3 +- go.sum | 73 +- server/message_cache.go | 701 ++++-------------- server/message_cache_pg.go | 179 +++++ server/message_cache_sqlite.go | 542 ++++++++++++++ server/message_cache_sqlite_test.go | 254 +++++++ server/message_cache_test.go | 1025 ++++++++++----------------- server/server.go | 10 +- server/server_test.go | 6 +- server/visitor.go | 4 +- 10 files changed, 1488 insertions(+), 1309 deletions(-) create mode 100644 server/message_cache_pg.go create mode 100644 server/message_cache_sqlite.go create mode 100644 server/message_cache_sqlite_test.go diff --git a/go.mod b/go.mod index e67e388a..3687b2b2 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require github.com/pkg/errors v0.9.1 // indirect require ( firebase.google.com/go/v4 v4.14.0 github.com/SherClockHolmes/webpush-go v1.3.0 + github.com/lib/pq v1.10.9 github.com/microcosm-cc/bluemonday v1.0.26 github.com/prometheus/client_golang v1.19.0 github.com/stripe/stripe-go/v74 v74.30.0 @@ -41,7 +42,6 @@ require ( cloud.google.com/go v0.112.2 // indirect cloud.google.com/go/auth v0.3.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute v1.25.1 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect cloud.google.com/go/iam v1.1.7 // indirect cloud.google.com/go/longrunning v0.5.6 // indirect @@ -81,7 +81,6 @@ require ( golang.org/x/net v0.24.0 // indirect golang.org/x/sys v0.19.0 // indirect golang.org/x/text v0.14.0 // indirect - google.golang.org/appengine v1.6.8 // indirect google.golang.org/appengine/v2 v2.0.6 // indirect google.golang.org/genproto v0.0.0-20240415180920-8c6c420018be // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240415180920-8c6c420018be // indirect diff --git a/go.sum b/go.sum index 4c9cd065..1e63474e 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,10 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.112.1 h1:uJSeirPke5UNZHIb4SxfZklVSiWWVqW4oXlETwZziwM= -cloud.google.com/go v0.112.1/go.mod h1:+Vbu+Y1UU+I1rjmzeMOb/8RfkKJK2Gyxi1X6jJCZLo4= cloud.google.com/go v0.112.2 h1:ZaGT6LiG7dBzi6zNOvVZwacaXlmf3lRqnC4DQzqyRQw= cloud.google.com/go v0.112.2/go.mod h1:iEqjp//KquGIJV/m+Pk3xecgKNhV+ry+vVTsy4TbDms= cloud.google.com/go/auth v0.3.0 h1:PRyzEpGfx/Z9e8+lHsbkoUVXD0gnu4MNmm7Gp8TQNIs= cloud.google.com/go/auth v0.3.0/go.mod h1:lBv6NKTWp8E3LPzmO1TbiiRKc4drLOfHsgmlH9ogv5w= cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= -cloud.google.com/go/compute v1.25.1 h1:ZRpHJedLtTpKgr3RV1Fx23NuaAEN1Zfx9hw1u4aJdjU= -cloud.google.com/go/compute v1.25.1/go.mod h1:oopOIR53ly6viBYxaDhBfJwzUAxf1zE//uf3IB011ls= -cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= -cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/firestore v1.15.0 h1:/k8ppuWOtNuDHt2tsRV42yI21uaGnKDEQnRFeBpbFF8= @@ -19,12 +13,8 @@ cloud.google.com/go/iam v1.1.7 h1:z4VHOhwKLF/+UYXAJDFwGtNF0b6gjsW1Pk9Ml0U/IoM= cloud.google.com/go/iam v1.1.7/go.mod h1:J4PMPg8TtyurAUvSmPj8FF3EDgY1SPRZxcUGrn7WXGA= cloud.google.com/go/longrunning v0.5.6 h1:xAe8+0YaWoCKr9t1+aWe+OeQgN/iJK1fEgZSXmjuEaE= cloud.google.com/go/longrunning v0.5.6/go.mod h1:vUaDrWYOMKRuhiv6JBnn49YxCPz2Ayn9GqyjaBT8/mA= -cloud.google.com/go/storage v1.39.1 h1:MvraqHKhogCOTXTlct/9C3K3+Uy2jBmFYb3/Sp6dVtY= -cloud.google.com/go/storage v1.39.1/go.mod h1:xK6xZmxZmo+fyP7+DEF6FhNc24/JAe95OLyOHCXFH1o= cloud.google.com/go/storage v1.40.0 h1:VEpDQV5CJxFmJ6ueWNsKxcr1QAYOXEgxDa+sBbJahPw= cloud.google.com/go/storage v1.40.0/go.mod h1:Rrj7/hKlG87BLqDJYtwR0fbPld8uJPbQ2ucUMY7Ir0g= -firebase.google.com/go/v4 v4.13.0 h1:meFz9nvDNh/FDyrEykoAzSfComcQbmnQSjoHrePRqeI= -firebase.google.com/go/v4 v4.13.0/go.mod h1:e1/gaR6EnbQfsmTnAMx1hnz+ninJIrrr/RAh59Tpfn8= firebase.google.com/go/v4 v4.14.0 h1:Tc9jWzMUApUFUA5UUx/HcBeZ+LPjlhG2vNRfWJrcMwU= firebase.google.com/go/v4 v4.14.0/go.mod h1:pLATyL6xH2o9AMe7rqHdmmOUE/Ph7wcwepIs+uiEKPg= github.com/AlekSi/pointer v1.2.0 h1:glcy/gc4h8HnG2Z3ZECSzZ1IX1x2JxRVuDzaJwQE0+w= @@ -41,8 +31,6 @@ github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd3 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= @@ -91,7 +79,6 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -122,6 +109,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58= @@ -135,18 +124,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.6.0 h1:k1v3CzpSRUTrKMppY35TLwPvxHqBu0bYgxZzqGIgaos= -github.com/prometheus/client_model v0.6.0/go.mod h1:NTQHnmxFpouOD0DpvP4XujX3CdOAGQPoaGhyTchlyt8= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= -github.com/prometheus/common v0.50.0 h1:YSZE6aa9+luNa2da6/Tik0q0A5AbR+U003TItK57CPQ= -github.com/prometheus/common v0.50.0/go.mod h1:wHFBCEVWVmHMUpg7pYcOm2QUR/ocQdYSJVQJKnHc3xQ= -github.com/prometheus/common v0.51.1 h1:eIjN50Bwglz6a/c3hAgSMcofL3nD+nFQkV6Dd4DsQCw= -github.com/prometheus/common v0.51.1/go.mod h1:lrWtQx+iDfn2mbH5GUzlH9TSHyfZpHkSiG1W7y3sF2Q= github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+aLCE= github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= -github.com/prometheus/procfs v0.13.0 h1:GqzLlQyfsPbaEHaQkO7tbDlriv/4o5Hudv6OXHGKX7o= -github.com/prometheus/procfs v0.13.0/go.mod h1:cd4PFCR54QLnGKPaKGA6l+cfuNXtht43ZKY6tow0Y1g= github.com/prometheus/procfs v0.14.0 h1:Lw4VdGGoKEZilJsayHf0B+9YgLGREba2C6xr+Fdfq6s= github.com/prometheus/procfs v0.14.0/go.mod h1:XL+Iwz8k8ZabyZfMFHPiilCniixqQarAy5Mu67pHlNQ= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= @@ -155,15 +136,14 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stripe/stripe-go/v74 v74.30.0 h1:0Kf0KkeFnY7iRhOwvTerX0Ia1BRw+eV1CVJ51mGYAUY= github.com/stripe/stripe-go/v74 v74.30.0/go.mod h1:f9L6LvaXa35ja7eyvP6GQswoaIPaBRvGAimAO+udbBw= @@ -174,34 +154,22 @@ github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913/go.mod h1:4aEEwZQut github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.50.0 h1:zvpPXY7RfYAGSdYQLjp6zxdJNSYD/+FFoCTQN9IPxBs= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.50.0/go.mod h1:BMn8NB1vsxTljvuorms2hyOs8IBuuBEq0pl7ltOfy30= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.50.0 h1:cEPbyTSEHlQR89XVlyo78gqluF8Y3oMeBkXGWzQsfXY= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.50.0/go.mod h1:DKdbWcT4GH1D0Y3Sqt/PFXt2naRKDWtU+eE6oLdFNA8= -go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= -go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= go.opentelemetry.io/otel v1.25.0 h1:gldB5FfhRl7OJQbUHt/8s0a7cE8fbsPAtdpRaApKy4k= go.opentelemetry.io/otel v1.25.0/go.mod h1:Wa2ds5NOXEMkCmUou1WA7ZBfLTHWIsp034OVD7AO+Vg= -go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= -go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= go.opentelemetry.io/otel/metric v1.25.0 h1:LUKbS7ArpFL/I2jJHdJcqMGxkRdxpPHE0VU/D4NuEwA= go.opentelemetry.io/otel/metric v1.25.0/go.mod h1:rkDLUSd2lC5lq2dFNrX9LGAbINP5B7WBkC78RXCpH5s= go.opentelemetry.io/otel/sdk v1.22.0 h1:6coWHw9xw7EfClIC/+O31R8IY3/+EiRFHevmHafB2Gw= go.opentelemetry.io/otel/sdk v1.22.0/go.mod h1:iu7luyVGYovrRpe2fmj3CVKouQNdTOkxtLzPvPz1DOc= -go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= -go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= go.opentelemetry.io/otel/trace v1.25.0 h1:tqukZGLwQYRIFtSQM2u2+yfMVTgGVeqRLPUYx1Dq6RM= go.opentelemetry.io/otel/trace v1.25.0/go.mod h1:hCCs70XM/ljO+BeQkyFnbK28SBIJ/Emuha+ccrCRT7I= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -222,13 +190,9 @@ golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= -golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI= -golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8= golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg= golang.org/x/oauth2 v0.19.0/go.mod h1:vYi7skDa1x015PmRRYZ7+s1cWyPgrPiSYRe4rnsexc8= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -236,8 +200,6 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -251,16 +213,12 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q= golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -286,39 +244,19 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= -google.golang.org/api v0.170.0 h1:zMaruDePM88zxZBG+NG8+reALO2rfLhe/JShitLyT48= -google.golang.org/api v0.170.0/go.mod h1:/xql9M2btF85xac/VAm4PsLMTLVGUOpq4BE9R8jyNy8= -google.golang.org/api v0.171.0 h1:w174hnBPqut76FzW5Qaupt7zY8Kql6fiVjgys4f58sU= -google.golang.org/api v0.171.0/go.mod h1:Hnq5AHm4OTMt2BUVjael2CWZFD6vksJdWCWiUAmjC9o= google.golang.org/api v0.176.1 h1:DJSXnV6An+NhJ1J+GWtoF2nHEuqB1VNoTfnIbjNvwD4= google.golang.org/api v0.176.1/go.mod h1:j2MaSDYcvYV1lkZ1+SMW4IeF90SrEyFA+tluDYWRrFg= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= -google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= -google.golang.org/appengine/v2 v2.0.5 h1:4C+F3Cd3L2nWEfSmFEZDPjQvDwL8T0YCeZBysZifP3k= -google.golang.org/appengine/v2 v2.0.5/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI= google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw= google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20240318140521-94a12d6c2237 h1:PgNlNSx2Nq2/j4juYzQBG0/Zdr+WP4z5N01Vk4VYBCY= -google.golang.org/genproto v0.0.0-20240318140521-94a12d6c2237/go.mod h1:9sVD8c25Af3p0rGs7S7LLsxWKFiJt/65LdSyqXBkX/Y= -google.golang.org/genproto v0.0.0-20240325203815-454cdb8f5daa h1:ePqxpG3LVx+feAUOx8YmR5T7rc0rdzK8DyxM8cQ9zq0= -google.golang.org/genproto v0.0.0-20240325203815-454cdb8f5daa/go.mod h1:CnZenrTdRJb7jc+jOm0Rkywq+9wh0QC4U8tyiRbEPPM= google.golang.org/genproto v0.0.0-20240415180920-8c6c420018be h1:g4aX8SUFA8V5F4LrSY5EclyGYw1OZN4HS1jTyjB9ZDc= google.golang.org/genproto v0.0.0-20240415180920-8c6c420018be/go.mod h1:FeSdT5fk+lkxatqJP38MsUicGqHax5cLtmy/6TAuxO4= -google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 h1:RFiFrvy37/mpSpdySBDrUdipW/dHwsRwh3J3+A9VgT4= -google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237/go.mod h1:Z5Iiy3jtmioajWHDGFk7CeugTyHtPvMHA4UTmUkyalE= -google.golang.org/genproto/googleapis/api v0.0.0-20240325203815-454cdb8f5daa h1:Jt1XW5PaLXF1/ePZrznsh/aAUvI7Adfc3LY1dAKlzRs= -google.golang.org/genproto/googleapis/api v0.0.0-20240325203815-454cdb8f5daa/go.mod h1:K4kfzHtI0kqWA79gecJarFtDn/Mls+GxQcg3Zox91Ac= google.golang.org/genproto/googleapis/api v0.0.0-20240415180920-8c6c420018be h1:Zz7rLWqp0ApfsR/l7+zSHhY3PMiH2xqgxlfYfAfNpoU= google.golang.org/genproto/googleapis/api v0.0.0-20240415180920-8c6c420018be/go.mod h1:dvdCTIoAGbkWbcIKBniID56/7XHTt6WfxXNMxuziJ+w= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 h1:NnYq6UN9ReLM9/Y01KWNOWyI5xQ9kbIms5GGJVwS/Yc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240325203815-454cdb8f5daa h1:RBgMaUMP+6soRkik4VoN8ojR2nex2TqZwjSSogic+eo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240325203815-454cdb8f5daa/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be h1:LG9vZxsWGOmUKieR8wPAUR3u3MpnYFQZROPIMaXh7/A= google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= @@ -326,8 +264,6 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.62.1 h1:B4n+nfKzOICUXMgyrNd19h/I9oH0L1pizfk1d4zSgTk= -google.golang.org/grpc v1.62.1/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE= google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -340,7 +276,6 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= diff --git a/server/message_cache.go b/server/message_cache.go index f0744abb..a1cda19b 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -4,332 +4,81 @@ import ( "database/sql" "encoding/json" "errors" - "fmt" + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/util" "net/netip" "strings" "time" - - _ "github.com/mattn/go-sqlite3" // SQLite driver - "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/util" ) -var ( - errUnexpectedMessageType = errors.New("unexpected message type") - errMessageNotFound = errors.New("message not found") - errNoRows = errors.New("no rows found") -) - -// Messages cache -const ( - createMessagesTableQuery = ` - BEGIN; - CREATE TABLE IF NOT EXISTS messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - mid TEXT NOT NULL, - time INT NOT NULL, - expires INT NOT NULL, - topic TEXT NOT NULL, - message TEXT NOT NULL, - title TEXT NOT NULL, - priority INT NOT NULL, - tags TEXT NOT NULL, - click TEXT NOT NULL, - icon TEXT NOT NULL, - actions TEXT NOT NULL, - attachment_name TEXT NOT NULL, - attachment_type TEXT NOT NULL, - attachment_size INT NOT NULL, - attachment_expires INT NOT NULL, - attachment_url TEXT NOT NULL, - attachment_deleted INT NOT NULL, - sender TEXT NOT NULL, - user TEXT NOT NULL, - content_type TEXT NOT NULL, - encoding TEXT NOT NULL, - published INT NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); - CREATE INDEX IF NOT EXISTS idx_time ON messages (time); - CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); - CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires); - CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender); - CREATE INDEX IF NOT EXISTS idx_user ON messages (user); - CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); - CREATE TABLE IF NOT EXISTS stats ( - key TEXT PRIMARY KEY, - value INT - ); - INSERT INTO stats (key, value) VALUES ('messages', 0); - COMMIT; - ` - insertMessageQuery = ` - INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - deleteMessageQuery = `DELETE FROM messages WHERE mid = ?` - updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?` - selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics - selectMessagesByIDQuery = ` - SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE mid = ? - ` - selectMessagesSinceTimeQuery = ` - SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE topic = ? AND time >= ? AND published = 1 - ORDER BY time, id - ` - selectMessagesSinceTimeIncludeScheduledQuery = ` - SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE topic = ? AND time >= ? - ORDER BY time, id - ` - selectMessagesSinceIDQuery = ` - SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE topic = ? AND id > ? AND published = 1 - ORDER BY time, id - ` - selectMessagesSinceIDIncludeScheduledQuery = ` - SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE topic = ? AND (id > ? OR published = 0) - ORDER BY time, id - ` - selectMessagesDueQuery = ` - SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE time <= ? AND published = 0 - ORDER BY time, id - ` - selectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1` - updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?` - selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` - selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` - selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` - - updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?` - selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0` - selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?` - selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` - - selectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'` - updateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'` -) - -// Schema management queries -const ( - currentSchemaVersion = 12 - createSchemaVersionTableQuery = ` - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - ` - insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` - updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` - selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` - - // 0 -> 1 - migrate0To1AlterMessagesTableQuery = ` - BEGIN; - ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0); - ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT(''); - COMMIT; - ` - - // 1 -> 2 - migrate1To2AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1); - ` - - // 2 -> 3 - migrate2To3AlterMessagesTableQuery = ` - BEGIN; - ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0'); - ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0'); - ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT(''); - COMMIT; - ` - // 3 -> 4 - migrate3To4AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT(''); - ` - - // 4 -> 5 - migrate4To5AlterMessagesTableQuery = ` - BEGIN; - CREATE TABLE IF NOT EXISTS messages_new ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - mid TEXT NOT NULL, - time INT NOT NULL, - topic TEXT NOT NULL, - message TEXT NOT NULL, - title TEXT NOT NULL, - priority INT NOT NULL, - tags TEXT NOT NULL, - click TEXT NOT NULL, - attachment_name TEXT NOT NULL, - attachment_type TEXT NOT NULL, - attachment_size INT NOT NULL, - attachment_expires INT NOT NULL, - attachment_url TEXT NOT NULL, - attachment_owner TEXT NOT NULL, - encoding TEXT NOT NULL, - published INT NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid); - CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic); - INSERT - INTO messages_new ( - mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, - attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published) - SELECT - id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, - attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published - FROM messages; - DROP TABLE messages; - ALTER TABLE messages_new RENAME TO messages; - COMMIT; - ` - - // 5 -> 6 - migrate5To6AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT(''); - ` - - // 6 -> 7 - migrate6To7AlterMessagesTableQuery = ` - ALTER TABLE messages RENAME COLUMN attachment_owner TO sender; - ` - - // 7 -> 8 - migrate7To8AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT(''); - ` - - // 8 -> 9 - migrate8To9AlterMessagesTableQuery = ` - CREATE INDEX IF NOT EXISTS idx_time ON messages (time); - ` - - // 9 -> 10 - migrate9To10AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0'); - ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0'); - CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires); - CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender); - CREATE INDEX IF NOT EXISTS idx_user ON messages (user); - CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); - ` - migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?` - - // 10 -> 11 - migrate10To11AlterMessagesTableQuery = ` - CREATE TABLE IF NOT EXISTS stats ( - key TEXT PRIMARY KEY, - value INT - ); - INSERT INTO stats (key, value) VALUES ('messages', 0); - ` - - // 11 -> 12 - migrate11To12AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT(''); - ` -) - -var ( - migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{ - 0: migrateFrom0, - 1: migrateFrom1, - 2: migrateFrom2, - 3: migrateFrom3, - 4: migrateFrom4, - 5: migrateFrom5, - 6: migrateFrom6, - 7: migrateFrom7, - 8: migrateFrom8, - 9: migrateFrom9, - 10: migrateFrom10, - 11: migrateFrom11, - } -) - -type messageCache struct { - db *sql.DB - queue *util.BatchingQueue[*message] - nop bool +type MessageCache interface { + AddMessage(m *message) error + AddMessages(ms []*message) error + Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) + MessagesDue() ([]*message, error) + MessagesExpired() ([]string, error) + Message(id string) (*message, error) + MarkPublished(m *message) error + MessageCounts() (map[string]int, error) + Topics() (map[string]*topic, error) + DeleteMessages(ids ...string) error + ExpireMessages(topics ...string) error + AttachmentsExpired() ([]string, error) + MarkAttachmentsDeleted(ids ...string) error + AttachmentBytesUsedBySender(sender string) (int64, error) + AttachmentBytesUsedByUser(userID string) (int64, error) + UpdateStats(messages int64) error + Stats() (messages int64, err error) + DB() *sql.DB + Close() error } -// newSqliteCache creates a SQLite file-backed cache -func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*messageCache, error) { - db, err := sql.Open("sqlite3", filename) - if err != nil { - return nil, err - } - if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil { - return nil, err - } - var queue *util.BatchingQueue[*message] - if batchSize > 0 || batchTimeout > 0 { - queue = util.NewBatchingQueue[*message](batchSize, batchTimeout) - } - cache := &messageCache{ - db: db, - queue: queue, - nop: nop, - } - go cache.processMessageBatches() - return cache, nil +type commonMessageCache struct { + db *sql.DB + queue *util.BatchingQueue[*message] + queries *messageCacheQueries } -// newMemCache creates an in-memory cache -func newMemCache() (*messageCache, error) { - return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, false) -} +var _ MessageCache = (*commonMessageCache)(nil) -// newNopCache creates an in-memory cache that discards all messages; -// it is always empty and can be used if caching is entirely disabled -func newNopCache() (*messageCache, error) { - return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, true) -} +type messageCacheQueries struct { + insertMessage string + deleteMessage string + updateMessagesForTopicExpiry string + selectRowIDFromMessageID string // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics + selectMessagesByID string + selectMessagesSinceTime string + selectMessagesSinceTimeIncludeScheduled string + selectMessagesSinceID string + selectMessagesSinceIDIncludeScheduled string + selectMessagesDue string + selectMessagesExpired string + updateMessagePublished string + selectMessageCountPerTopic string + selectTopics string -// createMemoryFilename creates a unique memory filename to use for the SQLite backend. -// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory -// sql database, so if the stdlib's sql engine happens to open another connection and -// you've only specified ":memory:", that connection will see a brand new database. -// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared"). -// Every connection to this string will point to the same in-memory database." -func createMemoryFilename() string { - return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10)) + updateAttachmentDeleted string + selectAttachmentsExpired string + selectAttachmentsSizeBySender string + selectAttachmentsSizeByUserID string + + selectStats string + updateStats string } // AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously. // The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor. -func (c *messageCache) AddMessage(m *message) error { +func (c *commonMessageCache) AddMessage(m *message) error { if c.queue != nil { c.queue.Enqueue(m) return nil } - return c.addMessages([]*message{m}) + return c.AddMessages([]*message{m}) } -// addMessages synchronously stores a match of messages. If the database is locked, the transaction waits until +// AddMessages synchronously stores a match of messages. If the database is locked, the transaction waits until // SQLite's busy_timeout is exceeded before erroring out. -func (c *messageCache) addMessages(ms []*message) error { - if c.nop { - return nil - } +func (c *commonMessageCache) AddMessages(ms []*message) error { if len(ms) == 0 { return nil } @@ -339,7 +88,7 @@ func (c *messageCache) addMessages(ms []*message) error { return err } defer tx.Rollback() - stmt, err := tx.Prepare(insertMessageQuery) + stmt, err := tx.Prepare(c.queries.insertMessage) if err != nil { return err } @@ -351,7 +100,8 @@ func (c *messageCache) addMessages(ms []*message) error { published := m.Time <= time.Now().Unix() tags := strings.Join(m.Tags, ",") var attachmentName, attachmentType, attachmentURL string - var attachmentSize, attachmentExpires, attachmentDeleted int64 + var attachmentSize, attachmentExpires int64 + var attachmentDeleted bool if m.Attachment != nil { attachmentName = m.Attachment.Name attachmentType = m.Attachment.Type @@ -388,7 +138,7 @@ func (c *messageCache) addMessages(ms []*message) error { attachmentSize, attachmentExpires, attachmentURL, - attachmentDeleted, // Always zero + attachmentDeleted, // Always false sender, m.User, m.ContentType, @@ -407,7 +157,7 @@ func (c *messageCache) addMessages(ms []*message) error { return nil } -func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) { +func (c *commonMessageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) { if since.IsNone() { return make([]*message, 0), nil } else if since.IsID() { @@ -416,13 +166,13 @@ func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool) return c.messagesSinceTime(topic, since, scheduled) } -func (c *messageCache) messagesSinceTime(topic string, since sinceMarker, scheduled bool) ([]*message, error) { +func (c *commonMessageCache) messagesSinceTime(topic string, since sinceMarker, scheduled bool) ([]*message, error) { var rows *sql.Rows var err error if scheduled { - rows, err = c.db.Query(selectMessagesSinceTimeIncludeScheduledQuery, topic, since.Time().Unix()) + rows, err = c.db.Query(c.queries.selectMessagesSinceTimeIncludeScheduled, topic, since.Time().Unix()) } else { - rows, err = c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix()) + rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix()) } if err != nil { return nil, err @@ -430,8 +180,8 @@ func (c *messageCache) messagesSinceTime(topic string, since sinceMarker, schedu return readMessages(rows) } -func (c *messageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) { - idrows, err := c.db.Query(selectRowIDFromMessageID, since.ID()) +func (c *commonMessageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) { + idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID()) if err != nil { return nil, err } @@ -446,9 +196,9 @@ func (c *messageCache) messagesSinceID(topic string, since sinceMarker, schedule idrows.Close() var rows *sql.Rows if scheduled { - rows, err = c.db.Query(selectMessagesSinceIDIncludeScheduledQuery, topic, rowID) + rows, err = c.db.Query(c.queries.selectMessagesSinceIDIncludeScheduled, topic, rowID) } else { - rows, err = c.db.Query(selectMessagesSinceIDQuery, topic, rowID) + rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID) } if err != nil { return nil, err @@ -456,8 +206,8 @@ func (c *messageCache) messagesSinceID(topic string, since sinceMarker, schedule return readMessages(rows) } -func (c *messageCache) MessagesDue() ([]*message, error) { - rows, err := c.db.Query(selectMessagesDueQuery, time.Now().Unix()) +func (c *commonMessageCache) MessagesDue() ([]*message, error) { + rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix()) if err != nil { return nil, err } @@ -465,8 +215,8 @@ func (c *messageCache) MessagesDue() ([]*message, error) { } // MessagesExpired returns a list of IDs for messages that have expires (should be deleted) -func (c *messageCache) MessagesExpired() ([]string, error) { - rows, err := c.db.Query(selectMessagesExpiredQuery, time.Now().Unix()) +func (c *commonMessageCache) MessagesExpired() ([]string, error) { + rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix()) if err != nil { return nil, err } @@ -485,25 +235,24 @@ func (c *messageCache) MessagesExpired() ([]string, error) { return ids, nil } -func (c *messageCache) Message(id string) (*message, error) { - rows, err := c.db.Query(selectMessagesByIDQuery, id) +func (c *commonMessageCache) Message(id string) (*message, error) { + rows, err := c.db.Query(c.queries.selectMessagesByID, id) if err != nil { return nil, err - } - if !rows.Next() { + } else if !rows.Next() { return nil, errMessageNotFound } defer rows.Close() return readMessage(rows) } -func (c *messageCache) MarkPublished(m *message) error { - _, err := c.db.Exec(updateMessagePublishedQuery, m.ID) +func (c *commonMessageCache) MarkPublished(m *message) error { + _, err := c.db.Exec(c.queries.updateMessagePublished, m.ID) return err } -func (c *messageCache) MessageCounts() (map[string]int, error) { - rows, err := c.db.Query(selectMessageCountPerTopicQuery) +func (c *commonMessageCache) MessageCounts() (map[string]int, error) { + rows, err := c.db.Query(c.queries.selectMessageCountPerTopic) if err != nil { return nil, err } @@ -522,8 +271,8 @@ func (c *messageCache) MessageCounts() (map[string]int, error) { return counts, nil } -func (c *messageCache) Topics() (map[string]*topic, error) { - rows, err := c.db.Query(selectTopicsQuery) +func (c *commonMessageCache) Topics() (map[string]*topic, error) { + rows, err := c.db.Query(c.queries.selectTopics) if err != nil { return nil, err } @@ -542,36 +291,36 @@ func (c *messageCache) Topics() (map[string]*topic, error) { return topics, nil } -func (c *messageCache) DeleteMessages(ids ...string) error { +func (c *commonMessageCache) DeleteMessages(ids ...string) error { tx, err := c.db.Begin() if err != nil { return err } defer tx.Rollback() for _, id := range ids { - if _, err := tx.Exec(deleteMessageQuery, id); err != nil { + if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil { return err } } return tx.Commit() } -func (c *messageCache) ExpireMessages(topics ...string) error { +func (c *commonMessageCache) ExpireMessages(topics ...string) error { tx, err := c.db.Begin() if err != nil { return err } defer tx.Rollback() for _, t := range topics { - if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix()-1, t); err != nil { + if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil { return err } } return tx.Commit() } -func (c *messageCache) AttachmentsExpired() ([]string, error) { - rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix()) +func (c *commonMessageCache) AttachmentsExpired() ([]string, error) { + rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix()) if err != nil { return nil, err } @@ -590,37 +339,37 @@ func (c *messageCache) AttachmentsExpired() ([]string, error) { return ids, nil } -func (c *messageCache) MarkAttachmentsDeleted(ids ...string) error { +func (c *commonMessageCache) MarkAttachmentsDeleted(ids ...string) error { tx, err := c.db.Begin() if err != nil { return err } defer tx.Rollback() for _, id := range ids { - if _, err := tx.Exec(updateAttachmentDeleted, id); err != nil { + if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil { return err } } return tx.Commit() } -func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) { - rows, err := c.db.Query(selectAttachmentsSizeBySenderQuery, sender, time.Now().Unix()) +func (c *commonMessageCache) AttachmentBytesUsedBySender(sender string) (int64, error) { + rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix()) if err != nil { return 0, err } return c.readAttachmentBytesUsed(rows) } -func (c *messageCache) AttachmentBytesUsedByUser(userID string) (int64, error) { - rows, err := c.db.Query(selectAttachmentsSizeByUserIDQuery, userID, time.Now().Unix()) +func (c *commonMessageCache) AttachmentBytesUsedByUser(userID string) (int64, error) { + rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix()) if err != nil { return 0, err } return c.readAttachmentBytesUsed(rows) } -func (c *messageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { +func (c *commonMessageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { defer rows.Close() var size int64 if !rows.Next() { @@ -634,17 +383,45 @@ func (c *messageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { return size, nil } -func (c *messageCache) processMessageBatches() { +func (c *commonMessageCache) processMessageBatches() { if c.queue == nil { return } for messages := range c.queue.Dequeue() { - if err := c.addMessages(messages); err != nil { + if err := c.AddMessages(messages); err != nil { log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch") } } } +func (c *commonMessageCache) UpdateStats(messages int64) error { + _, err := c.db.Exec(c.queries.updateStats, messages) + return err +} + +func (c *commonMessageCache) Stats() (messages int64, err error) { + rows, err := c.db.Query(c.queries.selectStats) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errNoRows + } + if err := rows.Scan(&messages); err != nil { + return 0, err + } + return messages, nil +} + +func (c *commonMessageCache) DB() *sql.DB { + return c.db +} + +func (c *commonMessageCache) Close() error { + return c.db.Close() +} + func readMessages(rows *sql.Rows) ([]*message, error) { defer rows.Close() messages := make([]*message, 0) @@ -734,239 +511,3 @@ func readMessage(rows *sql.Rows) (*message, error) { Encoding: encoding, }, nil } - -func (c *messageCache) UpdateStats(messages int64) error { - _, err := c.db.Exec(updateStatsQuery, messages) - return err -} - -func (c *messageCache) Stats() (messages int64, err error) { - rows, err := c.db.Query(selectStatsQuery) - if err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errNoRows - } - if err := rows.Scan(&messages); err != nil { - return 0, err - } - return messages, nil -} - -func (c *messageCache) Close() error { - return c.db.Close() -} - -func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error { - // Run startup queries - if startupQueries != "" { - if _, err := db.Exec(startupQueries); err != nil { - return err - } - } - - // If 'messages' table does not exist, this must be a new database - rowsMC, err := db.Query(selectMessagesCountQuery) - if err != nil { - return setupNewCacheDB(db) - } - rowsMC.Close() - - // If 'messages' table exists, check 'schemaVersion' table - schemaVersion := 0 - rowsSV, err := db.Query(selectSchemaVersionQuery) - if err == nil { - defer rowsSV.Close() - if !rowsSV.Next() { - return errors.New("cannot determine schema version: cache file may be corrupt") - } - if err := rowsSV.Scan(&schemaVersion); err != nil { - return err - } - rowsSV.Close() - } - - // Do migrations - if schemaVersion == currentSchemaVersion { - return nil - } else if schemaVersion > currentSchemaVersion { - return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion) - } - for i := schemaVersion; i < currentSchemaVersion; i++ { - fn, ok := migrations[i] - if !ok { - return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1) - } else if err := fn(db, cacheDuration); err != nil { - return err - } - } - return nil -} - -func setupNewCacheDB(db *sql.DB) error { - if _, err := db.Exec(createMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(createSchemaVersionTableQuery); err != nil { - return err - } - if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil { - return err - } - return nil -} - -func migrateFrom0(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1") - if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(createSchemaVersionTableQuery); err != nil { - return err - } - if _, err := db.Exec(insertSchemaVersion, 1); err != nil { - return err - } - return nil -} - -func migrateFrom1(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2") - if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 2); err != nil { - return err - } - return nil -} - -func migrateFrom2(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3") - if _, err := db.Exec(migrate2To3AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 3); err != nil { - return err - } - return nil -} - -func migrateFrom3(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4") - if _, err := db.Exec(migrate3To4AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 4); err != nil { - return err - } - return nil -} - -func migrateFrom4(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5") - if _, err := db.Exec(migrate4To5AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 5); err != nil { - return err - } - return nil -} - -func migrateFrom5(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6") - if _, err := db.Exec(migrate5To6AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 6); err != nil { - return err - } - return nil -} - -func migrateFrom6(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7") - if _, err := db.Exec(migrate6To7AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 7); err != nil { - return err - } - return nil -} - -func migrateFrom7(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8") - if _, err := db.Exec(migrate7To8AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 8); err != nil { - return err - } - return nil -} - -func migrateFrom8(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9") - if _, err := db.Exec(migrate8To9AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 9); err != nil { - return err - } - return nil -} - -func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate9To10AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(migrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 10); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom10(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 11); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom11(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 12); err != nil { - return err - } - return tx.Commit() -} diff --git a/server/message_cache_pg.go b/server/message_cache_pg.go new file mode 100644 index 00000000..02314928 --- /dev/null +++ b/server/message_cache_pg.go @@ -0,0 +1,179 @@ +package server + +import ( + "database/sql" + _ "github.com/lib/pq" // PostgreSQL driver + "heckel.io/ntfy/v2/util" + "time" +) + +// Messages cache +const ( + pgCreateMessagesTableQuery = ` + BEGIN; + CREATE TABLE IF NOT EXISTS messages ( + id SERIAL PRIMARY KEY, + mid TEXT NOT NULL, + time INT NOT NULL, + expires INT NOT NULL, + topic TEXT NOT NULL, + message TEXT NOT NULL, + title TEXT NOT NULL, + priority INT NOT NULL, + tags TEXT NOT NULL, + click TEXT NOT NULL, + icon TEXT NOT NULL, + actions TEXT NOT NULL, + attachment_name TEXT NOT NULL, + attachment_type TEXT NOT NULL, + attachment_size INT NOT NULL, + attachment_expires INT NOT NULL, + attachment_url TEXT NOT NULL, + attachment_deleted BOOLEAN NOT NULL, + sender TEXT NOT NULL, + "user" TEXT NOT NULL, + content_type TEXT NOT NULL, + encoding TEXT NOT NULL, + published BOOLEAN NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); + CREATE INDEX IF NOT EXISTS idx_time ON messages (time); + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires); + CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender); + CREATE INDEX IF NOT EXISTS idx_user ON messages ("user"); + CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); + CREATE TABLE IF NOT EXISTS stats ( + key TEXT PRIMARY KEY, + value INT + ); + INSERT INTO stats (key, value) VALUES ('messages', 0); + COMMIT; + ` + + pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages` +) + +var ( + pgMessageCacheQueries = &messageCacheQueries{ + insertMessage: ` + INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, "user", content_type, encoding, published) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22) + `, + deleteMessage: `DELETE FROM messages WHERE mid = $1`, + updateMessagesForTopicExpiry: `UPDATE messages SET expires = $1 WHERE topic = $2`, + selectRowIDFromMessageID: `SELECT id FROM messages WHERE mid = $1`, // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics + selectMessagesByID: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding + FROM messages + WHERE mid = $1 + `, + selectMessagesSinceTime: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding + FROM messages + WHERE topic = $1 AND time >= $2 AND published = TRUE + ORDER BY time, id + `, + selectMessagesSinceTimeIncludeScheduled: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding + FROM messages + WHERE topic = $1 AND time >= $2 + ORDER BY time, id + `, + selectMessagesSinceID: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding + FROM messages + WHERE topic = $1 AND id > $2 AND published = TRUE + ORDER BY time, id + `, + selectMessagesSinceIDIncludeScheduled: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding + FROM messages + WHERE topic = $1 AND (id > $2 OR published = FALSE) + ORDER BY time, id + `, + selectMessagesDue: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, "user", content_type, encoding + FROM messages + WHERE time <= $1 AND published = FALSE + ORDER BY time, id + `, + selectMessagesExpired: `SELECT mid FROM messages WHERE expires <= $1 AND published = TRUE`, + updateMessagePublished: `UPDATE messages SET published = TRUE WHERE mid = $1`, + selectMessageCountPerTopic: `SELECT topic, COUNT(*) FROM messages GROUP BY topic`, + selectTopics: `SELECT topic FROM messages GROUP BY topic`, + + updateAttachmentDeleted: `UPDATE messages SET attachment_deleted = TRUE WHERE mid = $1`, + selectAttachmentsExpired: `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`, + selectAttachmentsSizeBySender: `SELECT COALESCE(SUM(attachment_size), 0) FROM messages WHERE "user" = '' AND sender = $1 AND attachment_expires >= $2`, + selectAttachmentsSizeByUserID: `SELECT COALESCE(SUM(attachment_size), 0) FROM messages WHERE "user" = $1 AND attachment_expires >= $2`, + + selectStats: `SELECT value FROM stats WHERE key = 'messages'`, + updateStats: `UPDATE stats SET value = $1 WHERE key = 'messages'`, + } +) + +type pgMessageCache struct { + *commonMessageCache +} + +var _ MessageCache = (*pgMessageCache)(nil) + +func newPgMessageCache(connectionString, startupQueries string, batchSize int, batchTimeout time.Duration) (*pgMessageCache, error) { + db, err := sql.Open("postgres", connectionString) + if err != nil { + return nil, err + } + if err := setupPgMessagesDB(db, startupQueries); err != nil { + return nil, err + } + var queue *util.BatchingQueue[*message] + if batchSize > 0 || batchTimeout > 0 { + queue = util.NewBatchingQueue[*message](batchSize, batchTimeout) + } + cache := &pgMessageCache{ + commonMessageCache: &commonMessageCache{ + db: db, + queue: queue, + queries: pgMessageCacheQueries, + }, + } + go cache.processMessageBatches() + return cache, nil +} + +func setupPgMessagesDB(db *sql.DB, startupQueries string) error { + // Run startup queries + if startupQueries != "" { + if _, err := db.Exec(startupQueries); err != nil { + return err + } + } + + // If 'messages' table does not exist, this must be a new database + rowsMC, err := db.Query(pgSelectMessagesCountQuery) + if err != nil { + return setupNewPgCacheDB(db) + } + rowsMC.Close() + + return nil + + // FIXME schema migration +} + +func setupNewPgCacheDB(db *sql.DB) error { + if _, err := db.Exec(pgCreateMessagesTableQuery); err != nil { + return err + } + /* + // FIXME + if _, err := db.Exec(pgCreateSchemaVersionTableQuery); err != nil { + return err + } + if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil { + return err + } + */ + return nil +} diff --git a/server/message_cache_sqlite.go b/server/message_cache_sqlite.go new file mode 100644 index 00000000..f9c2acbf --- /dev/null +++ b/server/message_cache_sqlite.go @@ -0,0 +1,542 @@ +package server + +import ( + "database/sql" + "errors" + "fmt" + "time" + + _ "github.com/mattn/go-sqlite3" // SQLite driver + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/util" +) + +var ( + errUnexpectedMessageType = errors.New("unexpected message type") + errMessageNotFound = errors.New("message not found") + errNoRows = errors.New("no rows found") +) + +// Messages cache +const ( + createMessagesTableQuery = ` + BEGIN; + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + mid TEXT NOT NULL, + time INT NOT NULL, + expires INT NOT NULL, + topic TEXT NOT NULL, + message TEXT NOT NULL, + title TEXT NOT NULL, + priority INT NOT NULL, + tags TEXT NOT NULL, + click TEXT NOT NULL, + icon TEXT NOT NULL, + actions TEXT NOT NULL, + attachment_name TEXT NOT NULL, + attachment_type TEXT NOT NULL, + attachment_size INT NOT NULL, + attachment_expires INT NOT NULL, + attachment_url TEXT NOT NULL, + attachment_deleted INT NOT NULL, + sender TEXT NOT NULL, + user TEXT NOT NULL, + content_type TEXT NOT NULL, + encoding TEXT NOT NULL, + published INT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); + CREATE INDEX IF NOT EXISTS idx_time ON messages (time); + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires); + CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender); + CREATE INDEX IF NOT EXISTS idx_user ON messages (user); + CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); + CREATE TABLE IF NOT EXISTS stats ( + key TEXT PRIMARY KEY, + value INT + ); + INSERT INTO stats (key, value) VALUES ('messages', 0); + COMMIT; + ` +) + +var ( + sqliteMessageCacheQueries = &messageCacheQueries{ + insertMessage: ` + INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + `, + deleteMessage: `DELETE FROM messages WHERE mid = ?`, + updateMessagesForTopicExpiry: `UPDATE messages SET expires = ? WHERE topic = ?`, + selectRowIDFromMessageID: `SELECT id FROM messages WHERE mid = ?`, // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics + selectMessagesByID: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE mid = ? + `, + selectMessagesSinceTime: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE topic = ? AND time >= ? AND published = 1 + ORDER BY time, id + `, + selectMessagesSinceTimeIncludeScheduled: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE topic = ? AND time >= ? + ORDER BY time, id + `, + selectMessagesSinceID: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE topic = ? AND id > ? AND published = 1 + ORDER BY time, id + `, + selectMessagesSinceIDIncludeScheduled: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE topic = ? AND (id > ? OR published = 0) + ORDER BY time, id + `, + selectMessagesDue: ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE time <= ? AND published = 0 + ORDER BY time, id + `, + selectMessagesExpired: `SELECT mid FROM messages WHERE expires <= ? AND published = 1`, + updateMessagePublished: `UPDATE messages SET published = 1 WHERE mid = ?`, + selectMessageCountPerTopic: `SELECT topic, COUNT(*) FROM messages GROUP BY topic`, + selectTopics: `SELECT topic FROM messages GROUP BY topic`, + + updateAttachmentDeleted: `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`, + selectAttachmentsExpired: `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`, + selectAttachmentsSizeBySender: `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`, + selectAttachmentsSizeByUserID: `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`, + + selectStats: `SELECT value FROM stats WHERE key = 'messages'`, + updateStats: `UPDATE stats SET value = ? WHERE key = 'messages'`, + } +) + +// Schema management queries +const ( + currentSchemaVersion = 12 + createSchemaVersionTableQuery = ` + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + ` + insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` + updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` + selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` + selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` + + // 0 -> 1 + migrate0To1AlterMessagesTableQuery = ` + BEGIN; + ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0); + ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT(''); + COMMIT; + ` + + // 1 -> 2 + migrate1To2AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1); + ` + + // 2 -> 3 + migrate2To3AlterMessagesTableQuery = ` + BEGIN; + ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0'); + ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0'); + ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT(''); + COMMIT; + ` + // 3 -> 4 + migrate3To4AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT(''); + ` + + // 4 -> 5 + migrate4To5AlterMessagesTableQuery = ` + BEGIN; + CREATE TABLE IF NOT EXISTS messages_new ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + mid TEXT NOT NULL, + time INT NOT NULL, + topic TEXT NOT NULL, + message TEXT NOT NULL, + title TEXT NOT NULL, + priority INT NOT NULL, + tags TEXT NOT NULL, + click TEXT NOT NULL, + attachment_name TEXT NOT NULL, + attachment_type TEXT NOT NULL, + attachment_size INT NOT NULL, + attachment_expires INT NOT NULL, + attachment_url TEXT NOT NULL, + attachment_owner TEXT NOT NULL, + encoding TEXT NOT NULL, + published INT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid); + CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic); + INSERT + INTO messages_new ( + mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, + attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published) + SELECT + id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, + attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published + FROM messages; + DROP TABLE messages; + ALTER TABLE messages_new RENAME TO messages; + COMMIT; + ` + + // 5 -> 6 + migrate5To6AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT(''); + ` + + // 6 -> 7 + migrate6To7AlterMessagesTableQuery = ` + ALTER TABLE messages RENAME COLUMN attachment_owner TO sender; + ` + + // 7 -> 8 + migrate7To8AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT(''); + ` + + // 8 -> 9 + migrate8To9AlterMessagesTableQuery = ` + CREATE INDEX IF NOT EXISTS idx_time ON messages (time); + ` + + // 9 -> 10 + migrate9To10AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0'); + ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0'); + CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires); + CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender); + CREATE INDEX IF NOT EXISTS idx_user ON messages (user); + CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); + ` + migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?` + + // 10 -> 11 + migrate10To11AlterMessagesTableQuery = ` + CREATE TABLE IF NOT EXISTS stats ( + key TEXT PRIMARY KEY, + value INT + ); + INSERT INTO stats (key, value) VALUES ('messages', 0); + ` + + // 11 -> 12 + migrate11To12AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT(''); + ` +) + +var ( + migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{ + 0: migrateFrom0, + 1: migrateFrom1, + 2: migrateFrom2, + 3: migrateFrom3, + 4: migrateFrom4, + 5: migrateFrom5, + 6: migrateFrom6, + 7: migrateFrom7, + 8: migrateFrom8, + 9: migrateFrom9, + 10: migrateFrom10, + 11: migrateFrom11, + } +) + +type sqliteMessageCache struct { + *commonMessageCache + nop bool +} + +var _ MessageCache = (*sqliteMessageCache)(nil) + +// newSqliteMessageCache creates a SQLite file-backed cache +func newSqliteMessageCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*sqliteMessageCache, error) { + db, err := sql.Open("sqlite3", filename) + if err != nil { + return nil, err + } + if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil { + return nil, err + } + var queue *util.BatchingQueue[*message] + if batchSize > 0 || batchTimeout > 0 { + queue = util.NewBatchingQueue[*message](batchSize, batchTimeout) + } + cache := &sqliteMessageCache{ + commonMessageCache: &commonMessageCache{ + db: db, + queue: queue, + queries: sqliteMessageCacheQueries, + }, + nop: nop, + } + go cache.processMessageBatches() + return cache, nil +} + +// newMemCache creates an in-memory cache +func newMemCache() (*sqliteMessageCache, error) { + return newSqliteMessageCache(createMemoryFilename(), "", 0, 0, 0, false) +} + +// newNopCache creates an in-memory cache that discards all messages; +// it is always empty and can be used if caching is entirely disabled +func newNopCache() (*sqliteMessageCache, error) { + return newSqliteMessageCache(createMemoryFilename(), "", 0, 0, 0, true) +} + +// createMemoryFilename creates a unique memory filename to use for the SQLite backend. +// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory +// sql database, so if the stdlib's sql engine happens to open another connection and +// you've only specified ":memory:", that connection will see a brand new database. +// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared"). +// Every connection to this string will point to the same in-memory database." +func createMemoryFilename() string { + return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10)) +} + +// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously. +// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor. +func (c *sqliteMessageCache) AddMessage(m *message) error { + if c.nop { + return nil + } + return c.commonMessageCache.AddMessage(m) +} + +func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error { + // Run startup queries + if startupQueries != "" { + if _, err := db.Exec(startupQueries); err != nil { + return err + } + } + + // If 'messages' table does not exist, this must be a new database + rowsMC, err := db.Query(selectMessagesCountQuery) + if err != nil { + return setupNewCacheDB(db) + } + rowsMC.Close() + + // If 'messages' table exists, check 'schemaVersion' table + schemaVersion := 0 + rowsSV, err := db.Query(selectSchemaVersionQuery) + if err == nil { + defer rowsSV.Close() + if !rowsSV.Next() { + return errors.New("cannot determine schema version: cache file may be corrupt") + } + if err := rowsSV.Scan(&schemaVersion); err != nil { + return err + } + rowsSV.Close() + } + + // Do migrations + if schemaVersion == currentSchemaVersion { + return nil + } else if schemaVersion > currentSchemaVersion { + return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion) + } + for i := schemaVersion; i < currentSchemaVersion; i++ { + fn, ok := migrations[i] + if !ok { + return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1) + } else if err := fn(db, cacheDuration); err != nil { + return err + } + } + return nil +} + +func setupNewCacheDB(db *sql.DB) error { + if _, err := db.Exec(createMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(createSchemaVersionTableQuery); err != nil { + return err + } + if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil { + return err + } + return nil +} + +func migrateFrom0(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1") + if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(createSchemaVersionTableQuery); err != nil { + return err + } + if _, err := db.Exec(insertSchemaVersion, 1); err != nil { + return err + } + return nil +} + +func migrateFrom1(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2") + if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 2); err != nil { + return err + } + return nil +} + +func migrateFrom2(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3") + if _, err := db.Exec(migrate2To3AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 3); err != nil { + return err + } + return nil +} + +func migrateFrom3(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4") + if _, err := db.Exec(migrate3To4AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 4); err != nil { + return err + } + return nil +} + +func migrateFrom4(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5") + if _, err := db.Exec(migrate4To5AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 5); err != nil { + return err + } + return nil +} + +func migrateFrom5(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6") + if _, err := db.Exec(migrate5To6AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 6); err != nil { + return err + } + return nil +} + +func migrateFrom6(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7") + if _, err := db.Exec(migrate6To7AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 7); err != nil { + return err + } + return nil +} + +func migrateFrom7(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8") + if _, err := db.Exec(migrate7To8AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 8); err != nil { + return err + } + return nil +} + +func migrateFrom8(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9") + if _, err := db.Exec(migrate8To9AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(updateSchemaVersion, 9); err != nil { + return err + } + return nil +} + +func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate9To10AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(migrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 10); err != nil { + return err + } + return tx.Commit() +} + +func migrateFrom10(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 11); err != nil { + return err + } + return tx.Commit() +} + +func migrateFrom11(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 12); err != nil { + return err + } + return tx.Commit() +} diff --git a/server/message_cache_sqlite_test.go b/server/message_cache_sqlite_test.go new file mode 100644 index 00000000..1f4bc9a2 --- /dev/null +++ b/server/message_cache_sqlite_test.go @@ -0,0 +1,254 @@ +package server + +import ( + "database/sql" + "fmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestSqliteCache_Migration_From0(t *testing.T) { + filename := newSqliteTestCacheFile(t) + db, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + + // Create "version 0" schema + _, err = db.Exec(` + BEGIN; + CREATE TABLE IF NOT EXISTS messages ( + id VARCHAR(20) PRIMARY KEY, + time INT NOT NULL, + topic VARCHAR(64) NOT NULL, + message VARCHAR(1024) NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + COMMIT; + `) + require.Nil(t, err) + + // Insert a bunch of messages + for i := 0; i < 10; i++ { + _, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`, + fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i)) + require.Nil(t, err) + } + require.Nil(t, db.Close()) + + // Create cache to trigger migration + c := newSqliteTestCacheFromFile(t, filename, "") + checkSchemaVersion(t, c.db) + + messages, err := c.Messages("mytopic", sinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 10, len(messages)) + require.Equal(t, "some message 5", messages[5].Message) + require.Equal(t, "", messages[5].Title) + require.Nil(t, messages[5].Tags) + require.Equal(t, 0, messages[5].Priority) +} + +func TestSqliteCache_Migration_From1(t *testing.T) { + filename := newSqliteTestCacheFile(t) + db, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + + // Create "version 1" schema + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS messages ( + id VARCHAR(20) PRIMARY KEY, + time INT NOT NULL, + topic VARCHAR(64) NOT NULL, + message VARCHAR(512) NOT NULL, + title VARCHAR(256) NOT NULL, + priority INT NOT NULL, + tags VARCHAR(256) NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO schemaVersion (id, version) VALUES (1, 1); + `) + require.Nil(t, err) + + // Insert a bunch of messages + for i := 0; i < 10; i++ { + _, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`, + fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "") + require.Nil(t, err) + } + require.Nil(t, db.Close()) + + // Create cache to trigger migration + c := newSqliteTestCacheFromFile(t, filename, "") + checkSchemaVersion(t, c.db) + + // Add delayed message + delayedMessage := newDefaultMessage("mytopic", "some delayed message") + delayedMessage.Time = time.Now().Add(time.Minute).Unix() + require.Nil(t, c.AddMessage(delayedMessage)) + + // 10, not 11! + messages, err := c.Messages("mytopic", sinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 10, len(messages)) + + // 11! + messages, err = c.Messages("mytopic", sinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 11, len(messages)) +} + +func TestSqliteCache_Migration_From9(t *testing.T) { + // This primarily tests the awkward migration that introduces the "expires" column. + // The migration logic has to update the column, using the existing "cache-duration" value. + + filename := newSqliteTestCacheFile(t) + db, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + + // Create "version 8" schema + _, err = db.Exec(` + BEGIN; + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + mid TEXT NOT NULL, + time INT NOT NULL, + topic TEXT NOT NULL, + message TEXT NOT NULL, + title TEXT NOT NULL, + priority INT NOT NULL, + tags TEXT NOT NULL, + click TEXT NOT NULL, + icon TEXT NOT NULL, + actions TEXT NOT NULL, + attachment_name TEXT NOT NULL, + attachment_type TEXT NOT NULL, + attachment_size INT NOT NULL, + attachment_expires INT NOT NULL, + attachment_url TEXT NOT NULL, + sender TEXT NOT NULL, + encoding TEXT NOT NULL, + published INT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); + CREATE INDEX IF NOT EXISTS idx_time ON messages (time); + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO schemaVersion (id, version) VALUES (1, 9); + COMMIT; + `) + require.Nil(t, err) + + // Insert a bunch of messages + insertQuery := ` + INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + for i := 0; i < 10; i++ { + _, err = db.Exec( + insertQuery, + fmt.Sprintf("abcd%d", i), + time.Now().Unix(), + "mytopic", + fmt.Sprintf("some message %d", i), + "", // title + 0, // priority + "", // tags + "", // click + "", // icon + "", // actions + "", // attachment_name + "", // attachment_type + 0, // attachment_size + 0, // attachment_type + "", // attachment_url + "9.9.9.9", // sender + "", // encoding + 1, // published + ) + require.Nil(t, err) + } + + // Create cache to trigger migration + cacheDuration := 17 * time.Hour + c, err := newSqliteMessageCache(filename, "", cacheDuration, 0, 0, false) + require.Nil(t, err) + checkSchemaVersion(t, c.db) + + // Check version + rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`) + require.Nil(t, err) + require.True(t, rows.Next()) + var version int + require.Nil(t, rows.Scan(&version)) + require.Equal(t, currentSchemaVersion, version) + + messages, err := c.Messages("mytopic", sinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 10, len(messages)) + for _, m := range messages { + require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix()) + require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix()) + } +} + +func TestSqliteCache_StartupQueries_WAL(t *testing.T) { + filename := newSqliteTestCacheFile(t) + startupQueries := `pragma journal_mode = WAL; +pragma synchronous = normal; +pragma temp_store = memory;` + db, err := newSqliteMessageCache(filename, startupQueries, time.Hour, 0, 0, false) + require.Nil(t, err) + require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message"))) + require.FileExists(t, filename) + require.FileExists(t, filename+"-wal") + require.FileExists(t, filename+"-shm") +} + +func TestSqliteCache_StartupQueries_None(t *testing.T) { + filename := newSqliteTestCacheFile(t) + startupQueries := "" + db, err := newSqliteMessageCache(filename, startupQueries, time.Hour, 0, 0, false) + require.Nil(t, err) + require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message"))) + require.FileExists(t, filename) + require.NoFileExists(t, filename+"-wal") + require.NoFileExists(t, filename+"-shm") +} + +func TestSqliteCache_StartupQueries_Fail(t *testing.T) { + filename := newSqliteTestCacheFile(t) + startupQueries := `xx error` + _, err := newSqliteMessageCache(filename, startupQueries, time.Hour, 0, 0, false) + require.Error(t, err) +} + +func TestMemCache_NopCache(t *testing.T) { + c, _ := newNopCache() + assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message"))) + + messages, err := c.Messages("mytopic", sinceAllMessages, false) + assert.Nil(t, err) + assert.Empty(t, messages) + + topics, err := c.Topics() + assert.Nil(t, err) + assert.Empty(t, topics) +} +func checkSchemaVersion(t *testing.T, db *sql.DB) { + rows, err := db.Query(`SELECT version FROM schemaVersion`) + require.Nil(t, err) + require.True(t, rows.Next()) + + var schemaVersion int + require.Nil(t, rows.Scan(&schemaVersion)) + require.Equal(t, currentSchemaVersion, schemaVersion) + require.Nil(t, rows.Close()) +} diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 79b7fc54..43e5bc0c 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -2,692 +2,392 @@ package server import ( "database/sql" - "fmt" "net/netip" + "os" "path/filepath" "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestSqliteCache_Messages(t *testing.T) { - testCacheMessages(t, newSqliteTestCache(t)) -} +func TestCache_Messages(t *testing.T) { + runMessageCacheTest(t, func(t *testing.T, c MessageCache) { + m1 := newDefaultMessage("mytopic", "my message") + m1.Time = 1 -func TestMemCache_Messages(t *testing.T) { - testCacheMessages(t, newMemTestCache(t)) -} + m2 := newDefaultMessage("mytopic", "my other message") + m2.Time = 2 -func testCacheMessages(t *testing.T, c *messageCache) { - m1 := newDefaultMessage("mytopic", "my message") - m1.Time = 1 + require.Nil(t, c.AddMessage(m1)) + require.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message"))) + require.Nil(t, c.AddMessage(m2)) - m2 := newDefaultMessage("mytopic", "my other message") - m2.Time = 2 + // Adding invalid + require.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added! + require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added! - require.Nil(t, c.AddMessage(m1)) - require.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message"))) - require.Nil(t, c.AddMessage(m2)) - - // Adding invalid - require.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added! - require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added! - - // mytopic: count - counts, err := c.MessageCounts() - require.Nil(t, err) - require.Equal(t, 2, counts["mytopic"]) - - // mytopic: since all - messages, _ := c.Messages("mytopic", sinceAllMessages, false) - require.Equal(t, 2, len(messages)) - require.Equal(t, "my message", messages[0].Message) - require.Equal(t, "mytopic", messages[0].Topic) - require.Equal(t, messageEvent, messages[0].Event) - require.Equal(t, "", messages[0].Title) - require.Equal(t, 0, messages[0].Priority) - require.Nil(t, messages[0].Tags) - require.Equal(t, "my other message", messages[1].Message) - - // mytopic: since none - messages, _ = c.Messages("mytopic", sinceNoMessages, false) - require.Empty(t, messages) - - // mytopic: since m1 (by ID) - messages, _ = c.Messages("mytopic", newSinceID(m1.ID), false) - require.Equal(t, 1, len(messages)) - require.Equal(t, m2.ID, messages[0].ID) - require.Equal(t, "my other message", messages[0].Message) - require.Equal(t, "mytopic", messages[0].Topic) - - // mytopic: since 2 - messages, _ = c.Messages("mytopic", newSinceTime(2), false) - require.Equal(t, 1, len(messages)) - require.Equal(t, "my other message", messages[0].Message) - - // example: count - counts, err = c.MessageCounts() - require.Nil(t, err) - require.Equal(t, 1, counts["example"]) - - // example: since all - messages, _ = c.Messages("example", sinceAllMessages, false) - require.Equal(t, "my example message", messages[0].Message) - - // non-existing: count - counts, err = c.MessageCounts() - require.Nil(t, err) - require.Equal(t, 0, counts["doesnotexist"]) - - // non-existing: since all - messages, _ = c.Messages("doesnotexist", sinceAllMessages, false) - require.Empty(t, messages) -} - -func TestSqliteCache_MessagesScheduled(t *testing.T) { - testCacheMessagesScheduled(t, newSqliteTestCache(t)) -} - -func TestMemCache_MessagesScheduled(t *testing.T) { - testCacheMessagesScheduled(t, newMemTestCache(t)) -} - -func testCacheMessagesScheduled(t *testing.T, c *messageCache) { - m1 := newDefaultMessage("mytopic", "message 1") - m2 := newDefaultMessage("mytopic", "message 2") - m2.Time = time.Now().Add(time.Hour).Unix() - m3 := newDefaultMessage("mytopic", "message 3") - m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2! - m4 := newDefaultMessage("mytopic2", "message 4") - m4.Time = time.Now().Add(time.Minute).Unix() - require.Nil(t, c.AddMessage(m1)) - require.Nil(t, c.AddMessage(m2)) - require.Nil(t, c.AddMessage(m3)) - - messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled - require.Equal(t, 1, len(messages)) - require.Equal(t, "message 1", messages[0].Message) - - messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled - require.Equal(t, 3, len(messages)) - require.Equal(t, "message 1", messages[0].Message) - require.Equal(t, "message 3", messages[1].Message) // Order! - require.Equal(t, "message 2", messages[2].Message) - - messages, _ = c.MessagesDue() - require.Empty(t, messages) -} - -func TestSqliteCache_Topics(t *testing.T) { - testCacheTopics(t, newSqliteTestCache(t)) -} - -func TestMemCache_Topics(t *testing.T) { - testCacheTopics(t, newMemTestCache(t)) -} - -func testCacheTopics(t *testing.T, c *messageCache) { - require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message"))) - require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1"))) - require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2"))) - require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3"))) - - topics, err := c.Topics() - if err != nil { - t.Fatal(err) - } - require.Equal(t, 2, len(topics)) - require.Equal(t, "topic1", topics["topic1"].ID) - require.Equal(t, "topic2", topics["topic2"].ID) -} - -func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) { - testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t)) -} - -func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) { - testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t)) -} - -func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) { - m := newDefaultMessage("mytopic", "some message") - m.Tags = []string{"tag1", "tag2"} - m.Priority = 5 - m.Title = "some title" - require.Nil(t, c.AddMessage(m)) - - messages, _ := c.Messages("mytopic", sinceAllMessages, false) - require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags) - require.Equal(t, 5, messages[0].Priority) - require.Equal(t, "some title", messages[0].Title) -} - -func TestSqliteCache_MessagesSinceID(t *testing.T) { - testCacheMessagesSinceID(t, newSqliteTestCache(t)) -} - -func TestMemCache_MessagesSinceID(t *testing.T) { - testCacheMessagesSinceID(t, newMemTestCache(t)) -} - -func testCacheMessagesSinceID(t *testing.T, c *messageCache) { - m1 := newDefaultMessage("mytopic", "message 1") - m1.Time = 100 - m2 := newDefaultMessage("mytopic", "message 2") - m2.Time = 200 - m3 := newDefaultMessage("mytopic", "message 3") - m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5 - m4 := newDefaultMessage("mytopic", "message 4") - m4.Time = 400 - m5 := newDefaultMessage("mytopic", "message 5") - m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7 - m6 := newDefaultMessage("mytopic", "message 6") - m6.Time = 600 - m7 := newDefaultMessage("mytopic", "message 7") - m7.Time = 700 - - require.Nil(t, c.AddMessage(m1)) - require.Nil(t, c.AddMessage(m2)) - require.Nil(t, c.AddMessage(m3)) - require.Nil(t, c.AddMessage(m4)) - require.Nil(t, c.AddMessage(m5)) - require.Nil(t, c.AddMessage(m6)) - require.Nil(t, c.AddMessage(m7)) - - // Case 1: Since ID exists, exclude scheduled - messages, _ := c.Messages("mytopic", newSinceID(m2.ID), false) - require.Equal(t, 3, len(messages)) - require.Equal(t, "message 4", messages[0].Message) - require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5! - require.Equal(t, "message 7", messages[2].Message) - - // Case 2: Since ID exists, include scheduled - messages, _ = c.Messages("mytopic", newSinceID(m2.ID), true) - require.Equal(t, 5, len(messages)) - require.Equal(t, "message 4", messages[0].Message) - require.Equal(t, "message 6", messages[1].Message) - require.Equal(t, "message 7", messages[2].Message) - require.Equal(t, "message 5", messages[3].Message) // Order! - require.Equal(t, "message 3", messages[4].Message) // Order! - - // Case 3: Since ID does not exist (-> Return all messages), include scheduled - messages, _ = c.Messages("mytopic", newSinceID("doesntexist"), true) - require.Equal(t, 7, len(messages)) - require.Equal(t, "message 1", messages[0].Message) - require.Equal(t, "message 2", messages[1].Message) - require.Equal(t, "message 4", messages[2].Message) - require.Equal(t, "message 6", messages[3].Message) - require.Equal(t, "message 7", messages[4].Message) - require.Equal(t, "message 5", messages[5].Message) // Order! - require.Equal(t, "message 3", messages[6].Message) // Order! - - // Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled - messages, _ = c.Messages("mytopic", newSinceID(m7.ID), false) - require.Equal(t, 0, len(messages)) - - // Case 5: Since ID exists and is last message (-> Return no messages), include scheduled - messages, _ = c.Messages("mytopic", newSinceID(m7.ID), true) - require.Equal(t, 2, len(messages)) - require.Equal(t, "message 5", messages[0].Message) - require.Equal(t, "message 3", messages[1].Message) -} - -func TestSqliteCache_Prune(t *testing.T) { - testCachePrune(t, newSqliteTestCache(t)) -} - -func TestMemCache_Prune(t *testing.T) { - testCachePrune(t, newMemTestCache(t)) -} - -func testCachePrune(t *testing.T, c *messageCache) { - now := time.Now().Unix() - - m1 := newDefaultMessage("mytopic", "my message") - m1.Time = now - 10 - m1.Expires = now - 5 - - m2 := newDefaultMessage("mytopic", "my other message") - m2.Time = now - 5 - m2.Expires = now + 5 // In the future - - m3 := newDefaultMessage("another_topic", "and another one") - m3.Time = now - 12 - m3.Expires = now - 2 - - require.Nil(t, c.AddMessage(m1)) - require.Nil(t, c.AddMessage(m2)) - require.Nil(t, c.AddMessage(m3)) - - counts, err := c.MessageCounts() - require.Nil(t, err) - require.Equal(t, 2, counts["mytopic"]) - require.Equal(t, 1, counts["another_topic"]) - - expiredMessageIDs, err := c.MessagesExpired() - require.Nil(t, err) - require.Nil(t, c.DeleteMessages(expiredMessageIDs...)) - - counts, err = c.MessageCounts() - require.Nil(t, err) - require.Equal(t, 1, counts["mytopic"]) - require.Equal(t, 0, counts["another_topic"]) - - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "my other message", messages[0].Message) -} - -func TestSqliteCache_Attachments(t *testing.T) { - testCacheAttachments(t, newSqliteTestCache(t)) -} - -func TestMemCache_Attachments(t *testing.T) { - testCacheAttachments(t, newMemTestCache(t)) -} - -func testCacheAttachments(t *testing.T, c *messageCache) { - expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired - m := newDefaultMessage("mytopic", "flower for you") - m.ID = "m1" - m.Sender = netip.MustParseAddr("1.2.3.4") - m.Attachment = &attachment{ - Name: "flower.jpg", - Type: "image/jpeg", - Size: 5000, - Expires: expires1, - URL: "https://ntfy.sh/file/AbDeFgJhal.jpg", - } - require.Nil(t, c.AddMessage(m)) - - expires2 := time.Now().Add(2 * time.Hour).Unix() // Future - m = newDefaultMessage("mytopic", "sending you a car") - m.ID = "m2" - m.Sender = netip.MustParseAddr("1.2.3.4") - m.Attachment = &attachment{ - Name: "car.jpg", - Type: "image/jpeg", - Size: 10000, - Expires: expires2, - URL: "https://ntfy.sh/file/aCaRURL.jpg", - } - require.Nil(t, c.AddMessage(m)) - - expires3 := time.Now().Add(1 * time.Hour).Unix() // Future - m = newDefaultMessage("another-topic", "sending you another car") - m.ID = "m3" - m.User = "u_BAsbaAa" - m.Sender = netip.MustParseAddr("5.6.7.8") - m.Attachment = &attachment{ - Name: "another-car.jpg", - Type: "image/jpeg", - Size: 20000, - Expires: expires3, - URL: "https://ntfy.sh/file/zakaDHFW.jpg", - } - require.Nil(t, c.AddMessage(m)) - - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) - - require.Equal(t, "flower for you", messages[0].Message) - require.Equal(t, "flower.jpg", messages[0].Attachment.Name) - require.Equal(t, "image/jpeg", messages[0].Attachment.Type) - require.Equal(t, int64(5000), messages[0].Attachment.Size) - require.Equal(t, expires1, messages[0].Attachment.Expires) - require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) - require.Equal(t, "1.2.3.4", messages[0].Sender.String()) - - require.Equal(t, "sending you a car", messages[1].Message) - require.Equal(t, "car.jpg", messages[1].Attachment.Name) - require.Equal(t, "image/jpeg", messages[1].Attachment.Type) - require.Equal(t, int64(10000), messages[1].Attachment.Size) - require.Equal(t, expires2, messages[1].Attachment.Expires) - require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) - require.Equal(t, "1.2.3.4", messages[1].Sender.String()) - - size, err := c.AttachmentBytesUsedBySender("1.2.3.4") - require.Nil(t, err) - require.Equal(t, int64(10000), size) - - size, err = c.AttachmentBytesUsedBySender("5.6.7.8") - require.Nil(t, err) - require.Equal(t, int64(0), size) // Accounted to the user, not the IP! - - size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa") - require.Nil(t, err) - require.Equal(t, int64(20000), size) -} - -func TestSqliteCache_Attachments_Expired(t *testing.T) { - testCacheAttachmentsExpired(t, newSqliteTestCache(t)) -} - -func TestMemCache_Attachments_Expired(t *testing.T) { - testCacheAttachmentsExpired(t, newMemTestCache(t)) -} - -func testCacheAttachmentsExpired(t *testing.T, c *messageCache) { - m := newDefaultMessage("mytopic", "flower for you") - m.ID = "m1" - m.Expires = time.Now().Add(time.Hour).Unix() - require.Nil(t, c.AddMessage(m)) - - m = newDefaultMessage("mytopic", "message with attachment") - m.ID = "m2" - m.Expires = time.Now().Add(2 * time.Hour).Unix() - m.Attachment = &attachment{ - Name: "car.jpg", - Type: "image/jpeg", - Size: 10000, - Expires: time.Now().Add(2 * time.Hour).Unix(), - URL: "https://ntfy.sh/file/aCaRURL.jpg", - } - require.Nil(t, c.AddMessage(m)) - - m = newDefaultMessage("mytopic", "message with external attachment") - m.ID = "m3" - m.Expires = time.Now().Add(2 * time.Hour).Unix() - m.Attachment = &attachment{ - Name: "car.jpg", - Type: "image/jpeg", - Expires: 0, // Unknown! - URL: "https://somedomain.com/car.jpg", - } - require.Nil(t, c.AddMessage(m)) - - m = newDefaultMessage("mytopic2", "message with expired attachment") - m.ID = "m4" - m.Expires = time.Now().Add(2 * time.Hour).Unix() - m.Attachment = &attachment{ - Name: "expired-car.jpg", - Type: "image/jpeg", - Size: 20000, - Expires: time.Now().Add(-1 * time.Hour).Unix(), - URL: "https://ntfy.sh/file/aCaRURL.jpg", - } - require.Nil(t, c.AddMessage(m)) - - ids, err := c.AttachmentsExpired() - require.Nil(t, err) - require.Equal(t, 1, len(ids)) - require.Equal(t, "m4", ids[0]) -} - -func TestSqliteCache_Migration_From0(t *testing.T) { - filename := newSqliteTestCacheFile(t) - db, err := sql.Open("sqlite3", filename) - require.Nil(t, err) - - // Create "version 0" schema - _, err = db.Exec(` - BEGIN; - CREATE TABLE IF NOT EXISTS messages ( - id VARCHAR(20) PRIMARY KEY, - time INT NOT NULL, - topic VARCHAR(64) NOT NULL, - message VARCHAR(1024) NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); - COMMIT; - `) - require.Nil(t, err) - - // Insert a bunch of messages - for i := 0; i < 10; i++ { - _, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`, - fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i)) + // mytopic: count + counts, err := c.MessageCounts() require.Nil(t, err) - } - require.Nil(t, db.Close()) + require.Equal(t, 2, counts["mytopic"]) - // Create cache to trigger migration - c := newSqliteTestCacheFromFile(t, filename, "") - checkSchemaVersion(t, c.db) + // mytopic: since all + messages, _ := c.Messages("mytopic", sinceAllMessages, false) + require.Equal(t, 2, len(messages)) + require.Equal(t, "my message", messages[0].Message) + require.Equal(t, "mytopic", messages[0].Topic) + require.Equal(t, messageEvent, messages[0].Event) + require.Equal(t, "", messages[0].Title) + require.Equal(t, 0, messages[0].Priority) + require.Nil(t, messages[0].Tags) + require.Equal(t, "my other message", messages[1].Message) - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 10, len(messages)) - require.Equal(t, "some message 5", messages[5].Message) - require.Equal(t, "", messages[5].Title) - require.Nil(t, messages[5].Tags) - require.Equal(t, 0, messages[5].Priority) -} + // mytopic: since none + messages, _ = c.Messages("mytopic", sinceNoMessages, false) + require.Empty(t, messages) -func TestSqliteCache_Migration_From1(t *testing.T) { - filename := newSqliteTestCacheFile(t) - db, err := sql.Open("sqlite3", filename) - require.Nil(t, err) + // mytopic: since m1 (by ID) + messages, _ = c.Messages("mytopic", newSinceID(m1.ID), false) + require.Equal(t, 1, len(messages)) + require.Equal(t, m2.ID, messages[0].ID) + require.Equal(t, "my other message", messages[0].Message) + require.Equal(t, "mytopic", messages[0].Topic) - // Create "version 1" schema - _, err = db.Exec(` - CREATE TABLE IF NOT EXISTS messages ( - id VARCHAR(20) PRIMARY KEY, - time INT NOT NULL, - topic VARCHAR(64) NOT NULL, - message VARCHAR(512) NOT NULL, - title VARCHAR(256) NOT NULL, - priority INT NOT NULL, - tags VARCHAR(256) NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - INSERT INTO schemaVersion (id, version) VALUES (1, 1); - `) - require.Nil(t, err) + // mytopic: since 2 + messages, _ = c.Messages("mytopic", newSinceTime(2), false) + require.Equal(t, 1, len(messages)) + require.Equal(t, "my other message", messages[0].Message) - // Insert a bunch of messages - for i := 0; i < 10; i++ { - _, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`, - fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "") + // example: count + counts, err = c.MessageCounts() require.Nil(t, err) - } - require.Nil(t, db.Close()) + require.Equal(t, 1, counts["example"]) - // Create cache to trigger migration - c := newSqliteTestCacheFromFile(t, filename, "") - checkSchemaVersion(t, c.db) + // example: since all + messages, _ = c.Messages("example", sinceAllMessages, false) + require.Equal(t, "my example message", messages[0].Message) - // Add delayed message - delayedMessage := newDefaultMessage("mytopic", "some delayed message") - delayedMessage.Time = time.Now().Add(time.Minute).Unix() - require.Nil(t, c.AddMessage(delayedMessage)) - - // 10, not 11! - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 10, len(messages)) - - // 11! - messages, err = c.Messages("mytopic", sinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 11, len(messages)) -} - -func TestSqliteCache_Migration_From9(t *testing.T) { - // This primarily tests the awkward migration that introduces the "expires" column. - // The migration logic has to update the column, using the existing "cache-duration" value. - - filename := newSqliteTestCacheFile(t) - db, err := sql.Open("sqlite3", filename) - require.Nil(t, err) - - // Create "version 8" schema - _, err = db.Exec(` - BEGIN; - CREATE TABLE IF NOT EXISTS messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - mid TEXT NOT NULL, - time INT NOT NULL, - topic TEXT NOT NULL, - message TEXT NOT NULL, - title TEXT NOT NULL, - priority INT NOT NULL, - tags TEXT NOT NULL, - click TEXT NOT NULL, - icon TEXT NOT NULL, - actions TEXT NOT NULL, - attachment_name TEXT NOT NULL, - attachment_type TEXT NOT NULL, - attachment_size INT NOT NULL, - attachment_expires INT NOT NULL, - attachment_url TEXT NOT NULL, - sender TEXT NOT NULL, - encoding TEXT NOT NULL, - published INT NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); - CREATE INDEX IF NOT EXISTS idx_time ON messages (time); - CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - INSERT INTO schemaVersion (id, version) VALUES (1, 9); - COMMIT; - `) - require.Nil(t, err) - - // Insert a bunch of messages - insertQuery := ` - INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - for i := 0; i < 10; i++ { - _, err = db.Exec( - insertQuery, - fmt.Sprintf("abcd%d", i), - time.Now().Unix(), - "mytopic", - fmt.Sprintf("some message %d", i), - "", // title - 0, // priority - "", // tags - "", // click - "", // icon - "", // actions - "", // attachment_name - "", // attachment_type - 0, // attachment_size - 0, // attachment_type - "", // attachment_url - "9.9.9.9", // sender - "", // encoding - 1, // published - ) + // non-existing: count + counts, err = c.MessageCounts() require.Nil(t, err) - } + require.Equal(t, 0, counts["doesnotexist"]) - // Create cache to trigger migration - cacheDuration := 17 * time.Hour - c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false) - require.Nil(t, err) - checkSchemaVersion(t, c.db) - - // Check version - rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`) - require.Nil(t, err) - require.True(t, rows.Next()) - var version int - require.Nil(t, rows.Scan(&version)) - require.Equal(t, currentSchemaVersion, version) - - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 10, len(messages)) - for _, m := range messages { - require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix()) - require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix()) - } + // non-existing: since all + messages, _ = c.Messages("doesnotexist", sinceAllMessages, false) + require.Empty(t, messages) + }) } -func TestSqliteCache_StartupQueries_WAL(t *testing.T) { - filename := newSqliteTestCacheFile(t) - startupQueries := `pragma journal_mode = WAL; -pragma synchronous = normal; -pragma temp_store = memory;` - db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false) - require.Nil(t, err) - require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message"))) - require.FileExists(t, filename) - require.FileExists(t, filename+"-wal") - require.FileExists(t, filename+"-shm") +func TestCache_MessagesScheduled(t *testing.T) { + runMessageCacheTest(t, func(t *testing.T, c MessageCache) { + m1 := newDefaultMessage("mytopic", "message 1") + m2 := newDefaultMessage("mytopic", "message 2") + m2.Time = time.Now().Add(time.Hour).Unix() + m3 := newDefaultMessage("mytopic", "message 3") + m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2! + m4 := newDefaultMessage("mytopic2", "message 4") + m4.Time = time.Now().Add(time.Minute).Unix() + require.Nil(t, c.AddMessage(m1)) + require.Nil(t, c.AddMessage(m2)) + require.Nil(t, c.AddMessage(m3)) + + messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled + require.Equal(t, 1, len(messages)) + require.Equal(t, "message 1", messages[0].Message) + + messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled + require.Equal(t, 3, len(messages)) + require.Equal(t, "message 1", messages[0].Message) + require.Equal(t, "message 3", messages[1].Message) // Order! + require.Equal(t, "message 2", messages[2].Message) + + messages, _ = c.MessagesDue() + require.Empty(t, messages) + }) } -func TestSqliteCache_StartupQueries_None(t *testing.T) { - filename := newSqliteTestCacheFile(t) - startupQueries := "" - db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false) - require.Nil(t, err) - require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message"))) - require.FileExists(t, filename) - require.NoFileExists(t, filename+"-wal") - require.NoFileExists(t, filename+"-shm") +func TestCache_Topics(t *testing.T) { + runMessageCacheTest(t, func(t *testing.T, c MessageCache) { + require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message"))) + require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1"))) + require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2"))) + require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3"))) + + topics, err := c.Topics() + if err != nil { + t.Fatal(err) + } + require.Equal(t, 2, len(topics)) + require.Equal(t, "topic1", topics["topic1"].ID) + require.Equal(t, "topic2", topics["topic2"].ID) + }) } -func TestSqliteCache_StartupQueries_Fail(t *testing.T) { - filename := newSqliteTestCacheFile(t) - startupQueries := `xx error` - _, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false) - require.Error(t, err) +func TestCache_MessagesTagsPrioAndTitle(t *testing.T) { + runMessageCacheTest(t, func(t *testing.T, c MessageCache) { + m := newDefaultMessage("mytopic", "some message") + m.Tags = []string{"tag1", "tag2"} + m.Priority = 5 + m.Title = "some title" + require.Nil(t, c.AddMessage(m)) + + messages, _ := c.Messages("mytopic", sinceAllMessages, false) + require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags) + require.Equal(t, 5, messages[0].Priority) + require.Equal(t, "some title", messages[0].Title) + }) } -func TestSqliteCache_Sender(t *testing.T) { - testSender(t, newSqliteTestCache(t)) +func TestCache_MessagesSinceID(t *testing.T) { + runMessageCacheTest(t, func(t *testing.T, c MessageCache) { + m1 := newDefaultMessage("mytopic", "message 1") + m1.Time = 100 + m2 := newDefaultMessage("mytopic", "message 2") + m2.Time = 200 + m3 := newDefaultMessage("mytopic", "message 3") + m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5 + m4 := newDefaultMessage("mytopic", "message 4") + m4.Time = 400 + m5 := newDefaultMessage("mytopic", "message 5") + m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7 + m6 := newDefaultMessage("mytopic", "message 6") + m6.Time = 600 + m7 := newDefaultMessage("mytopic", "message 7") + m7.Time = 700 + + require.Nil(t, c.AddMessage(m1)) + require.Nil(t, c.AddMessage(m2)) + require.Nil(t, c.AddMessage(m3)) + require.Nil(t, c.AddMessage(m4)) + require.Nil(t, c.AddMessage(m5)) + require.Nil(t, c.AddMessage(m6)) + require.Nil(t, c.AddMessage(m7)) + + // Case 1: Since ID exists, exclude scheduled + messages, _ := c.Messages("mytopic", newSinceID(m2.ID), false) + require.Equal(t, 3, len(messages)) + require.Equal(t, "message 4", messages[0].Message) + require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5! + require.Equal(t, "message 7", messages[2].Message) + + // Case 2: Since ID exists, include scheduled + messages, _ = c.Messages("mytopic", newSinceID(m2.ID), true) + require.Equal(t, 5, len(messages)) + require.Equal(t, "message 4", messages[0].Message) + require.Equal(t, "message 6", messages[1].Message) + require.Equal(t, "message 7", messages[2].Message) + require.Equal(t, "message 5", messages[3].Message) // Order! + require.Equal(t, "message 3", messages[4].Message) // Order! + + // Case 3: Since ID does not exist (-> Return all messages), include scheduled + messages, _ = c.Messages("mytopic", newSinceID("doesntexist"), true) + require.Equal(t, 7, len(messages)) + require.Equal(t, "message 1", messages[0].Message) + require.Equal(t, "message 2", messages[1].Message) + require.Equal(t, "message 4", messages[2].Message) + require.Equal(t, "message 6", messages[3].Message) + require.Equal(t, "message 7", messages[4].Message) + require.Equal(t, "message 5", messages[5].Message) // Order! + require.Equal(t, "message 3", messages[6].Message) // Order! + + // Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled + messages, _ = c.Messages("mytopic", newSinceID(m7.ID), false) + require.Equal(t, 0, len(messages)) + + // Case 5: Since ID exists and is last message (-> Return no messages), include scheduled + messages, _ = c.Messages("mytopic", newSinceID(m7.ID), true) + require.Equal(t, 2, len(messages)) + require.Equal(t, "message 5", messages[0].Message) + require.Equal(t, "message 3", messages[1].Message) + }) } -func TestMemCache_Sender(t *testing.T) { - testSender(t, newMemTestCache(t)) +func TestCache_Prune(t *testing.T) { + runMessageCacheTest(t, func(t *testing.T, c MessageCache) { + now := time.Now().Unix() + + m1 := newDefaultMessage("mytopic", "my message") + m1.Time = now - 10 + m1.Expires = now - 5 + + m2 := newDefaultMessage("mytopic", "my other message") + m2.Time = now - 5 + m2.Expires = now + 5 // In the future + + m3 := newDefaultMessage("another_topic", "and another one") + m3.Time = now - 12 + m3.Expires = now - 2 + + require.Nil(t, c.AddMessage(m1)) + require.Nil(t, c.AddMessage(m2)) + require.Nil(t, c.AddMessage(m3)) + + counts, err := c.MessageCounts() + require.Nil(t, err) + require.Equal(t, 2, counts["mytopic"]) + require.Equal(t, 1, counts["another_topic"]) + + expiredMessageIDs, err := c.MessagesExpired() + require.Nil(t, err) + require.Nil(t, c.DeleteMessages(expiredMessageIDs...)) + + counts, err = c.MessageCounts() + require.Nil(t, err) + require.Equal(t, 1, counts["mytopic"]) + require.Equal(t, 0, counts["another_topic"]) + + messages, err := c.Messages("mytopic", sinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "my other message", messages[0].Message) + }) } -func testSender(t *testing.T, c *messageCache) { - m1 := newDefaultMessage("mytopic", "mymessage") - m1.Sender = netip.MustParseAddr("1.2.3.4") - require.Nil(t, c.AddMessage(m1)) +func TestCache_Attachments(t *testing.T) { + runMessageCacheTest(t, func(t *testing.T, c MessageCache) { + expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired + m := newDefaultMessage("mytopic", "flower for you") + m.ID = "m1" + m.Sender = netip.MustParseAddr("1.2.3.4") + m.Attachment = &attachment{ + Name: "flower.jpg", + Type: "image/jpeg", + Size: 5000, + Expires: expires1, + URL: "https://ntfy.sh/file/AbDeFgJhal.jpg", + } + require.Nil(t, c.AddMessage(m)) - m2 := newDefaultMessage("mytopic", "mymessage without sender") - require.Nil(t, c.AddMessage(m2)) + expires2 := time.Now().Add(2 * time.Hour).Unix() // Future + m = newDefaultMessage("mytopic", "sending you a car") + m.ID = "m2" + m.Sender = netip.MustParseAddr("1.2.3.4") + m.Attachment = &attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Size: 10000, + Expires: expires2, + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, c.AddMessage(m)) - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) - require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4")) - require.Equal(t, messages[1].Sender, netip.Addr{}) + expires3 := time.Now().Add(1 * time.Hour).Unix() // Future + m = newDefaultMessage("another-topic", "sending you another car") + m.ID = "m3" + m.User = "u_BAsbaAa" + m.Sender = netip.MustParseAddr("5.6.7.8") + m.Attachment = &attachment{ + Name: "another-car.jpg", + Type: "image/jpeg", + Size: 20000, + Expires: expires3, + URL: "https://ntfy.sh/file/zakaDHFW.jpg", + } + require.Nil(t, c.AddMessage(m)) + + messages, err := c.Messages("mytopic", sinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + + require.Equal(t, "flower for you", messages[0].Message) + require.Equal(t, "flower.jpg", messages[0].Attachment.Name) + require.Equal(t, "image/jpeg", messages[0].Attachment.Type) + require.Equal(t, int64(5000), messages[0].Attachment.Size) + require.Equal(t, expires1, messages[0].Attachment.Expires) + require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) + require.Equal(t, "1.2.3.4", messages[0].Sender.String()) + + require.Equal(t, "sending you a car", messages[1].Message) + require.Equal(t, "car.jpg", messages[1].Attachment.Name) + require.Equal(t, "image/jpeg", messages[1].Attachment.Type) + require.Equal(t, int64(10000), messages[1].Attachment.Size) + require.Equal(t, expires2, messages[1].Attachment.Expires) + require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) + require.Equal(t, "1.2.3.4", messages[1].Sender.String()) + + size, err := c.AttachmentBytesUsedBySender("1.2.3.4") + require.Nil(t, err) + require.Equal(t, int64(10000), size) + + size, err = c.AttachmentBytesUsedBySender("5.6.7.8") + require.Nil(t, err) + require.Equal(t, int64(0), size) // Accounted to the user, not the IP! + + size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa") + require.Nil(t, err) + require.Equal(t, int64(20000), size) + }) } -func checkSchemaVersion(t *testing.T, db *sql.DB) { - rows, err := db.Query(`SELECT version FROM schemaVersion`) - require.Nil(t, err) - require.True(t, rows.Next()) +func TestCache_AttachmentsExpired(t *testing.T) { + runMessageCacheTest(t, func(t *testing.T, c MessageCache) { + m := newDefaultMessage("mytopic", "flower for you") + m.ID = "m1" + m.Expires = time.Now().Add(time.Hour).Unix() + require.Nil(t, c.AddMessage(m)) - var schemaVersion int - require.Nil(t, rows.Scan(&schemaVersion)) - require.Equal(t, currentSchemaVersion, schemaVersion) - require.Nil(t, rows.Close()) + m = newDefaultMessage("mytopic", "message with attachment") + m.ID = "m2" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Size: 10000, + Expires: time.Now().Add(2 * time.Hour).Unix(), + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, c.AddMessage(m)) + + m = newDefaultMessage("mytopic", "message with external attachment") + m.ID = "m3" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Expires: 0, // Unknown! + URL: "https://somedomain.com/car.jpg", + } + require.Nil(t, c.AddMessage(m)) + + m = newDefaultMessage("mytopic2", "message with expired attachment") + m.ID = "m4" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &attachment{ + Name: "expired-car.jpg", + Type: "image/jpeg", + Size: 20000, + Expires: time.Now().Add(-1 * time.Hour).Unix(), + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, c.AddMessage(m)) + + ids, err := c.AttachmentsExpired() + require.Nil(t, err) + require.Equal(t, 1, len(ids)) + require.Equal(t, "m4", ids[0]) + }) } -func TestMemCache_NopCache(t *testing.T) { - c, _ := newNopCache() - assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message"))) +func TestCache_Sender(t *testing.T) { + runMessageCacheTest(t, func(t *testing.T, c MessageCache) { + m1 := newDefaultMessage("mytopic", "mymessage") + m1.Sender = netip.MustParseAddr("1.2.3.4") + require.Nil(t, c.AddMessage(m1)) - messages, err := c.Messages("mytopic", sinceAllMessages, false) - assert.Nil(t, err) - assert.Empty(t, messages) + m2 := newDefaultMessage("mytopic", "mymessage without sender") + require.Nil(t, c.AddMessage(m2)) - topics, err := c.Topics() - assert.Nil(t, err) - assert.Empty(t, topics) + messages, err := c.Messages("mytopic", sinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4")) + require.Equal(t, messages[1].Sender, netip.Addr{}) + }) } -func newSqliteTestCache(t *testing.T) *messageCache { - c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false) +func newSqliteTestCache(t *testing.T) *sqliteMessageCache { + c, err := newSqliteMessageCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false) if err != nil { t.Fatal(err) } @@ -698,18 +398,45 @@ func newSqliteTestCacheFile(t *testing.T) string { return filepath.Join(t.TempDir(), "cache.db") } -func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache { - c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false) - if err != nil { - t.Fatal(err) - } +func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *sqliteMessageCache { + c, err := newSqliteMessageCache(filename, startupQueries, time.Hour, 0, 0, false) + require.Nil(t, err) return c } -func newMemTestCache(t *testing.T) *messageCache { +func newMemTestCache(t *testing.T) MessageCache { c, err := newMemCache() - if err != nil { - t.Fatal(err) - } + require.Nil(t, err) return c } + +func newPgTestCache(t *testing.T) MessageCache { + connectionString := os.Getenv("NTFY_TEST_MESSAGES_CACHE_PG_CONNECTION_STRING") + if connectionString == "" { + t.Skip("Skipping test, because NTFY_TEST_MESSAGES_CACHE_PG_CONNECTION_STRING not set") + } + db, err := sql.Open("postgres", connectionString) + require.Nil(t, err) + _, err = db.Exec("DROP TABLE IF EXISTS messages") + require.Nil(t, err) + _, err = db.Exec("DROP TABLE IF EXISTS stats") + require.Nil(t, err) + _, err = db.Exec("DROP TABLE IF EXISTS schemaVersion") + require.Nil(t, err) + require.Nil(t, db.Close()) + c, err := newPgMessageCache(connectionString, "", 0, 0) + require.Nil(t, err) + return c +} + +func runMessageCacheTest(t *testing.T, f func(t *testing.T, c MessageCache)) { + t.Run(t.Name()+"_sqlite", func(t *testing.T) { + f(t, newSqliteTestCache(t)) + }) + t.Run(t.Name()+"_mem", func(t *testing.T) { + f(t, newMemTestCache(t)) + }) + t.Run(t.Name()+"_pg", func(t *testing.T) { + f(t, newPgTestCache(t)) + }) +} diff --git a/server/server.go b/server/server.go index c1ada4c6..f35399ff 100644 --- a/server/server.go +++ b/server/server.go @@ -53,7 +53,7 @@ type Server struct { messages int64 // Total number of messages (persisted if messageCache enabled) messagesHistory []int64 // Last n values of the messages counter, used to determine rate userManager *user.Manager // Might be nil! - messageCache *messageCache // Database that stores the messages + messageCache MessageCache // Database that stores the messages webPush *webPushStore // Database that stores web push subscriptions fileCache *fileCache // File system based cache that stores attachments stripe stripeAPI // Stripe API, can be replaced with a mock @@ -226,11 +226,13 @@ func New(conf *Config) (*Server, error) { return s, nil } -func createMessageCache(conf *Config) (*messageCache, error) { +func createMessageCache(conf *Config) (MessageCache, error) { if conf.CacheDuration == 0 { return newNopCache() + } else if strings.HasPrefix(conf.CacheFile, "postgres:") { + return newPgMessageCache(strings.TrimPrefix(conf.CacheFile, "postgres:"), conf.CacheStartupQueries, conf.CacheBatchSize, conf.CacheBatchTimeout) } else if conf.CacheFile != "" { - return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false) + return newSqliteMessageCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false) } return newMemCache() } @@ -1525,7 +1527,7 @@ func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topi return nil } -// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the +// sendOldMessages selects old messages from the sqliteMessageCache and calls sub for each of them. It uses since as the // marker, returning only messages that are newer than the marker. func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error { if since.IsNone() { diff --git a/server/server_test.go b/server/server_test.go index 1c800ce6..13fe8c12 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -389,7 +389,7 @@ func TestServer_PublishAt(t *testing.T) { // Update message time to the past fakeTime := time.Now().Add(-10 * time.Second).Unix() - _, err := s.messageCache.db.Exec(`UPDATE messages SET time=?`, fakeTime) + _, err := s.messageCache.DB().Exec(`UPDATE messages SET time=?`, fakeTime) require.Nil(t, err) // Trigger delayed message sending @@ -425,7 +425,7 @@ func TestServer_PublishAt_FromUser(t *testing.T) { // Update message time to the past fakeTime := time.Now().Add(-10 * time.Second).Unix() - _, err := s.messageCache.db.Exec(`UPDATE messages SET time=?`, fakeTime) + _, err := s.messageCache.DB().Exec(`UPDATE messages SET time=?`, fakeTime) require.Nil(t, err) // Trigger delayed message sending @@ -2189,7 +2189,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { require.Nil(t, err) messages = append(messages, newDefaultMessage(topicID, "some message")) } - require.Nil(t, s.messageCache.addMessages(messages)) + require.Nil(t, s.messageCache.AddMessages(messages)) log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond)) // Update stats diff --git a/server/visitor.go b/server/visitor.go index d542e773..964cfb21 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -53,7 +53,7 @@ const ( // visitor represents an API user, and its associated rate.Limiter used for rate limiting type visitor struct { config *Config - messageCache *messageCache + messageCache MessageCache userManager *user.Manager // May be nil ip netip.Addr // Visitor IP address user *user.User // Only set if authenticated user, otherwise nil @@ -114,7 +114,7 @@ const ( visitorLimitBasisTier = visitorLimitBasis("tier") ) -func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { +func newVisitor(conf *Config, messageCache MessageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { var messages, emails, calls int64 if user != nil { messages = user.Stats.Messages