mirror of
https://github.com/binwiederhier/ntfy.git
synced 2024-12-26 19:52:30 +01:00
fix: removes an issue with topic.Subscribe function not checking duplicate ID
This commit is contained in:
parent
6ad3b2e802
commit
d2fa768151
2 changed files with 47 additions and 3 deletions
|
@ -1,11 +1,12 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"heckel.io/ntfy/log"
|
|
||||||
"heckel.io/ntfy/util"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/log"
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -45,9 +46,23 @@ func newTopic(id string) *topic {
|
||||||
|
|
||||||
// Subscribe subscribes to this topic
|
// Subscribe subscribes to this topic
|
||||||
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
|
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
|
||||||
|
max_retries := 5
|
||||||
|
retries := 1
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
subscriberID := rand.Int()
|
subscriberID := rand.Int()
|
||||||
|
// simple check for existing id in maps
|
||||||
|
for {
|
||||||
|
_, ok := t.subscribers[subscriberID]
|
||||||
|
if ok && retries <= max_retries {
|
||||||
|
subscriberID = rand.Int()
|
||||||
|
retries++
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
t.subscribers[subscriberID] = &topicSubscriber{
|
t.subscribers[subscriberID] = &topicSubscriber{
|
||||||
userID: userID, // May be empty
|
userID: userID, // May be empty
|
||||||
subscriber: s,
|
subscriber: s,
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/require"
|
"math/rand"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTopic_CancelSubscribers(t *testing.T) {
|
func TestTopic_CancelSubscribers(t *testing.T) {
|
||||||
|
@ -39,3 +41,30 @@ func TestTopic_Keepalive(t *testing.T) {
|
||||||
require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2)
|
require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2)
|
||||||
require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2)
|
require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTopic_Subscribe_duplicateID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
to := newTopic("mytopic")
|
||||||
|
|
||||||
|
// fix random seed to force same number generation
|
||||||
|
rand.Seed(1)
|
||||||
|
a := rand.Int()
|
||||||
|
to.subscribers[a] = &topicSubscriber{
|
||||||
|
userID: "a",
|
||||||
|
subscriber: nil,
|
||||||
|
cancel: func() {},
|
||||||
|
}
|
||||||
|
|
||||||
|
subFn := func(v *visitor, msg *message) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// force rand.Int to generate the same id once more
|
||||||
|
rand.Seed(1)
|
||||||
|
id := to.Subscribe(subFn, "b", func() {})
|
||||||
|
res := to.subscribers[id]
|
||||||
|
|
||||||
|
require.False(t, id == a)
|
||||||
|
require.True(t, res.userID == "b")
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue