mirror of
				https://github.com/binwiederhier/ntfy.git
				synced 2025-10-31 13:02:24 +01:00 
			
		
		
		
	Combine things, move stuff
This commit is contained in:
		
							parent
							
								
									707c58a120
								
							
						
					
					
						commit
						2b6363474e
					
				
					 5 changed files with 231 additions and 187 deletions
				
			
		
							
								
								
									
										165
									
								
								server/server.go
									
										
									
									
									
								
							
							
						
						
									
										165
									
								
								server/server.go
									
										
									
									
									
								
							|  | @ -32,9 +32,6 @@ import ( | |||
| 	"unicode/utf8" | ||||
| ) | ||||
| 
 | ||||
| // TODO add "max messages in a topic" limit | ||||
| // TODO implement "since=<ID>" | ||||
| 
 | ||||
| // Server is the main server, providing the UI and API for ntfy | ||||
| type Server struct { | ||||
| 	config       *Config | ||||
|  | @ -59,25 +56,6 @@ type indexPage struct { | |||
| 	CacheDuration time.Duration | ||||
| } | ||||
| 
 | ||||
| type sinceTime time.Time | ||||
| 
 | ||||
| func (t sinceTime) IsAll() bool { | ||||
| 	return t == sinceAllMessages | ||||
| } | ||||
| 
 | ||||
| func (t sinceTime) IsNone() bool { | ||||
| 	return t == sinceNoMessages | ||||
| } | ||||
| 
 | ||||
| func (t sinceTime) Time() time.Time { | ||||
| 	return time.Time(t) | ||||
| } | ||||
| 
 | ||||
| var ( | ||||
| 	sinceAllMessages = sinceTime(time.Unix(0, 0)) | ||||
| 	sinceNoMessages  = sinceTime(time.Unix(1, 0)) | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	topicRegex       = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)  // No /! | ||||
| 	topicPathRegex   = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app! | ||||
|  | @ -117,7 +95,6 @@ const ( | |||
| 	firebaseControlTopic     = "~control"                // See Android if changed | ||||
| 	emptyMessageBody         = "triggered"               // Used if message body is empty | ||||
| 	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment | ||||
| 	fcmMessageLimit          = 4000                      // see maybeTruncateFCMMessage for details | ||||
| ) | ||||
| 
 | ||||
| // WebSocket constants | ||||
|  | @ -232,25 +209,6 @@ func createFirebaseSubscriber(conf *Config) (subscriber, error) { | |||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| // maybeTruncateFCMMessage performs best-effort truncation of FCM messages. | ||||
| // The docs say the limit is 4000 characters, but during testing it wasn't quite clear | ||||
| // what fields matter; so we're just capping the serialized JSON to 4000 bytes. | ||||
| func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message { | ||||
| 	s, err := json.Marshal(m) | ||||
| 	if err != nil { | ||||
| 		return m | ||||
| 	} | ||||
| 	if len(s) > fcmMessageLimit { | ||||
| 		over := len(s) - fcmMessageLimit + 16 // = len("truncated":"1",), sigh ... | ||||
| 		message, ok := m.Data["message"] | ||||
| 		if ok && len(message) > over { | ||||
| 			m.Data["truncated"] = "1" | ||||
| 			m.Data["message"] = message[:len(message)-over] | ||||
| 		} | ||||
| 	} | ||||
| 	return m | ||||
| } | ||||
| 
 | ||||
| // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts | ||||
| // a manager go routine to print stats and prune messages. | ||||
| func (s *Server) Run() error { | ||||
|  | @ -391,7 +349,7 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error { | |||
| } | ||||
| 
 | ||||
| func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request) error { | ||||
| 	unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see PUT/POST too! | ||||
| 	unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too! | ||||
| 	if unifiedpush { | ||||
| 		w.Header().Set("Content-Type", "application/json") | ||||
| 		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests | ||||
|  | @ -497,13 +455,15 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito | |||
| 	if err := json.NewEncoder(w).Encode(m); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	s.inc(&s.messages) | ||||
| 	s.mu.Lock() | ||||
| 	s.messages++ | ||||
| 	s.mu.Unlock() | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, err error) { | ||||
| 	cache = readParam(r, "x-cache", "cache") != "no" | ||||
| 	firebase = readParam(r, "x-firebase", "firebase") != "no" | ||||
| 	cache = readBoolParam(r, true, "x-cache", "cache") | ||||
| 	firebase = readBoolParam(r, true, "x-firebase", "firebase") | ||||
| 	m.Title = readParam(r, "x-title", "title", "t") | ||||
| 	m.Click = readParam(r, "x-click", "click") | ||||
| 	filename := readParam(r, "x-filename", "filename", "file", "f") | ||||
|  | @ -574,29 +534,13 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca | |||
| 		} | ||||
| 		m.Time = delay.Unix() | ||||
| 	} | ||||
| 	unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see GET too! | ||||
| 	unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too! | ||||
| 	if unifiedpush { | ||||
| 		firebase = false | ||||
| 	} | ||||
| 	return cache, firebase, email, nil | ||||
| } | ||||
| 
 | ||||
| func readParam(r *http.Request, names ...string) string { | ||||
| 	for _, name := range names { | ||||
| 		value := r.Header.Get(name) | ||||
| 		if value != "" { | ||||
| 			return strings.TrimSpace(value) | ||||
| 		} | ||||
| 	} | ||||
| 	for _, name := range names { | ||||
| 		value := r.URL.Query().Get(strings.ToLower(name)) | ||||
| 		if value != "" { | ||||
| 			return strings.TrimSpace(value) | ||||
| 		} | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message. | ||||
| // | ||||
| // 1. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic | ||||
|  | @ -680,7 +624,7 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v * | |||
| 		} | ||||
| 		return buf.String(), nil | ||||
| 	} | ||||
| 	return s.handleSubscribe(w, r, v, "json", "application/x-ndjson", encoder) | ||||
| 	return s.handleSubscribeHTTP(w, r, v, "application/x-ndjson", encoder) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||
|  | @ -694,7 +638,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v | |||
| 		} | ||||
| 		return fmt.Sprintf("data: %s\n", buf.String()), nil | ||||
| 	} | ||||
| 	return s.handleSubscribe(w, r, v, "sse", "text/event-stream", encoder) | ||||
| 	return s.handleSubscribeHTTP(w, r, v, "text/event-stream", encoder) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error { | ||||
|  | @ -704,33 +648,25 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v | |||
| 		} | ||||
| 		return "\n", nil // "keepalive" and "open" events just send an empty line | ||||
| 	} | ||||
| 	return s.handleSubscribe(w, r, v, "raw", "text/plain", encoder) | ||||
| 	return s.handleSubscribeHTTP(w, r, v, "text/plain", encoder) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error { | ||||
| func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error { | ||||
| 	if err := v.SubscriptionAllowed(); err != nil { | ||||
| 		return errHTTPTooManyRequestsLimitSubscriptions | ||||
| 	} | ||||
| 	defer v.RemoveSubscription() | ||||
| 	topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack | ||||
| 	topicIDs := util.SplitNoEmpty(topicsStr, ",") | ||||
| 	topics, err := s.topicsFromIDs(topicIDs...) | ||||
| 	topics, topicsStr, err := s.topicsFromPath(r.URL.Path) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	poll := readParam(r, "x-poll", "poll", "po") == "1" | ||||
| 	scheduled := readParam(r, "x-scheduled", "scheduled", "sched") == "1" | ||||
| 	since, err := parseSince(r, poll) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r) | ||||
| 	poll, since, scheduled, filters, err := parseSubscribeParams(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	var wlock sync.Mutex | ||||
| 	sub := func(msg *message) error { | ||||
| 		if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) { | ||||
| 		if !filters.Pass(msg) { | ||||
| 			return nil | ||||
| 		} | ||||
| 		m, err := encoder(msg) | ||||
|  | @ -785,19 +721,11 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | |||
| 		return errHTTPTooManyRequestsLimitSubscriptions | ||||
| 	} | ||||
| 	defer v.RemoveSubscription() | ||||
| 	topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/ws") // Hack | ||||
| 	topicIDs := util.SplitNoEmpty(topicsStr, ",") | ||||
| 	topics, err := s.topicsFromIDs(topicIDs...) | ||||
| 	topics, topicsStr, err := s.topicsFromPath(r.URL.Path) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	poll := readParam(r, "x-poll", "poll", "po") == "1" | ||||
| 	scheduled := readParam(r, "x-scheduled", "scheduled", "sched") == "1" | ||||
| 	since, err := parseSince(r, poll) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r) | ||||
| 	poll, since, scheduled, filters, err := parseSubscribeParams(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -850,7 +778,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | |||
| 		} | ||||
| 	}) | ||||
| 	sub := func(msg *message) error { | ||||
| 		if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) { | ||||
| 		if !filters.Pass(msg) { | ||||
| 			return nil | ||||
| 		} | ||||
| 		if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil { | ||||
|  | @ -884,42 +812,18 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi | |||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func parseQueryFilters(r *http.Request) (messageFilter string, titleFilter string, priorityFilter []int, tagsFilter []string, err error) { | ||||
| 	messageFilter = readParam(r, "x-message", "message", "m") | ||||
| 	titleFilter = readParam(r, "x-title", "title", "t") | ||||
| 	tagsFilter = util.SplitNoEmpty(readParam(r, "x-tags", "tags", "tag", "ta"), ",") | ||||
| 	priorityFilter = make([]int, 0) | ||||
| 	for _, p := range util.SplitNoEmpty(readParam(r, "x-priority", "priority", "prio", "p"), ",") { | ||||
| 		priority, err := util.ParsePriority(p) | ||||
| func parseSubscribeParams(r *http.Request) (poll bool, since sinceTime, scheduled bool, filters *queryFilter, err error) { | ||||
| 	poll = readBoolParam(r, false, "x-poll", "poll", "po") | ||||
| 	scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched") | ||||
| 	since, err = parseSince(r, poll) | ||||
| 	if err != nil { | ||||
| 			return "", "", nil, nil, err | ||||
| 		} | ||||
| 		priorityFilter = append(priorityFilter, priority) | ||||
| 	} | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| func passesQueryFilter(msg *message, messageFilter string, titleFilter string, priorityFilter []int, tagsFilter []string) bool { | ||||
| 	if msg.Event != messageEvent { | ||||
| 		return true // filters only apply to messages | ||||
| 	filters, err = parseQueryFilters(r) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	if messageFilter != "" && msg.Message != messageFilter { | ||||
| 		return false | ||||
| 	} | ||||
| 	if titleFilter != "" && msg.Title != titleFilter { | ||||
| 		return false | ||||
| 	} | ||||
| 	messagePriority := msg.Priority | ||||
| 	if messagePriority == 0 { | ||||
| 		messagePriority = 3 // For query filters, default priority (3) is the same as "not set" (0) | ||||
| 	} | ||||
| 	if len(priorityFilter) > 0 && !util.InIntList(priorityFilter, messagePriority) { | ||||
| 		return false | ||||
| 	} | ||||
| 	if len(tagsFilter) > 0 && !util.InStringListAll(msg.Tags, tagsFilter) { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (s *Server) sendOldMessages(topics []*topic, since sinceTime, scheduled bool, sub subscriber) error { | ||||
|  | @ -980,6 +884,19 @@ func (s *Server) topicFromPath(path string) (*topic, error) { | |||
| 	return topics[0], nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) topicsFromPath(path string) ([]*topic, string, error) { | ||||
| 	parts := strings.Split(path, "/") | ||||
| 	if len(parts) < 2 { | ||||
| 		return nil, "", errHTTPBadRequestTopicInvalid | ||||
| 	} | ||||
| 	topicIDs := util.SplitNoEmpty(parts[1], ",") | ||||
| 	topics, err := s.topicsFromIDs(topicIDs...) | ||||
| 	if err != nil { | ||||
| 		return nil, "", errHTTPBadRequestTopicInvalid | ||||
| 	} | ||||
| 	return topics, parts[1], nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
|  | @ -1180,9 +1097,3 @@ func (s *Server) visitor(r *http.Request) *visitor { | |||
| 	v.Keepalive() | ||||
| 	return v | ||||
| } | ||||
| 
 | ||||
| func (s *Server) inc(counter *int64) { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 	*counter++ | ||||
| } | ||||
|  |  | |||
|  | @ -4,7 +4,6 @@ import ( | |||
| 	"bufio" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"firebase.google.com/go/messaging" | ||||
| 	"fmt" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"heckel.io/ntfy/util" | ||||
|  | @ -624,63 +623,6 @@ func TestServer_UnifiedPushDiscovery(t *testing.T) { | |||
| 	require.Equal(t, `{"unifiedpush":{"version":1}}`+"\n", response.Body.String()) | ||||
| } | ||||
| 
 | ||||
| func TestServer_MaybeTruncateFCMMessage(t *testing.T) { | ||||
| 	origMessage := strings.Repeat("this is a long string", 300) | ||||
| 	origFCMMessage := &messaging.Message{ | ||||
| 		Topic: "mytopic", | ||||
| 		Data: map[string]string{ | ||||
| 			"id":       "abcdefg", | ||||
| 			"time":     "1641324761", | ||||
| 			"event":    "message", | ||||
| 			"topic":    "mytopic", | ||||
| 			"priority": "0", | ||||
| 			"tags":     "", | ||||
| 			"title":    "", | ||||
| 			"message":  origMessage, | ||||
| 		}, | ||||
| 		Android: &messaging.AndroidConfig{ | ||||
| 			Priority: "high", | ||||
| 		}, | ||||
| 	} | ||||
| 	origMessageLength := len(origFCMMessage.Data["message"]) | ||||
| 	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage) | ||||
| 	require.Greater(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition | ||||
| 
 | ||||
| 	truncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage) | ||||
| 	truncatedMessageLength := len(truncatedFCMMessage.Data["message"]) | ||||
| 	serializedTruncatedFCMMessage, _ := json.Marshal(truncatedFCMMessage) | ||||
| 	require.Equal(t, fcmMessageLimit, len(serializedTruncatedFCMMessage)) | ||||
| 	require.Equal(t, "1", truncatedFCMMessage.Data["truncated"]) | ||||
| 	require.NotEqual(t, origMessageLength, truncatedMessageLength) | ||||
| } | ||||
| 
 | ||||
| func TestServer_MaybeTruncateFCMMessage_NotTooLong(t *testing.T) { | ||||
| 	origMessage := "not really a long string" | ||||
| 	origFCMMessage := &messaging.Message{ | ||||
| 		Topic: "mytopic", | ||||
| 		Data: map[string]string{ | ||||
| 			"id":       "abcdefg", | ||||
| 			"time":     "1641324761", | ||||
| 			"event":    "message", | ||||
| 			"topic":    "mytopic", | ||||
| 			"priority": "0", | ||||
| 			"tags":     "", | ||||
| 			"title":    "", | ||||
| 			"message":  origMessage, | ||||
| 		}, | ||||
| 	} | ||||
| 	origMessageLength := len(origFCMMessage.Data["message"]) | ||||
| 	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage) | ||||
| 	require.LessOrEqual(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition | ||||
| 
 | ||||
| 	notTruncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage) | ||||
| 	notTruncatedMessageLength := len(notTruncatedFCMMessage.Data["message"]) | ||||
| 	serializedNotTruncatedFCMMessage, _ := json.Marshal(notTruncatedFCMMessage) | ||||
| 	require.Equal(t, origMessageLength, notTruncatedMessageLength) | ||||
| 	require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage)) | ||||
| 	require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"]) | ||||
| } | ||||
| 
 | ||||
| func TestServer_PublishAttachment(t *testing.T) { | ||||
| 	content := util.RandomString(5000) // > 4096 | ||||
| 	s := newTestServer(t, newTestConfig(t)) | ||||
|  |  | |||
|  | @ -2,6 +2,7 @@ package server | |||
| 
 | ||||
| import ( | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
|  | @ -70,3 +71,72 @@ func newKeepaliveMessage(topic string) *message { | |||
| func newDefaultMessage(topic, msg string) *message { | ||||
| 	return newMessage(messageEvent, topic, msg) | ||||
| } | ||||
| 
 | ||||
| type sinceTime time.Time | ||||
| 
 | ||||
| func (t sinceTime) IsAll() bool { | ||||
| 	return t == sinceAllMessages | ||||
| } | ||||
| 
 | ||||
| func (t sinceTime) IsNone() bool { | ||||
| 	return t == sinceNoMessages | ||||
| } | ||||
| 
 | ||||
| func (t sinceTime) Time() time.Time { | ||||
| 	return time.Time(t) | ||||
| } | ||||
| 
 | ||||
| var ( | ||||
| 	sinceAllMessages = sinceTime(time.Unix(0, 0)) | ||||
| 	sinceNoMessages  = sinceTime(time.Unix(1, 0)) | ||||
| ) | ||||
| 
 | ||||
| type queryFilter struct { | ||||
| 	Message  string | ||||
| 	Title    string | ||||
| 	Tags     []string | ||||
| 	Priority []int | ||||
| } | ||||
| 
 | ||||
| func parseQueryFilters(r *http.Request) (*queryFilter, error) { | ||||
| 	messageFilter := readParam(r, "x-message", "message", "m") | ||||
| 	titleFilter := readParam(r, "x-title", "title", "t") | ||||
| 	tagsFilter := util.SplitNoEmpty(readParam(r, "x-tags", "tags", "tag", "ta"), ",") | ||||
| 	priorityFilter := make([]int, 0) | ||||
| 	for _, p := range util.SplitNoEmpty(readParam(r, "x-priority", "priority", "prio", "p"), ",") { | ||||
| 		priority, err := util.ParsePriority(p) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		priorityFilter = append(priorityFilter, priority) | ||||
| 	} | ||||
| 	return &queryFilter{ | ||||
| 		Message:  messageFilter, | ||||
| 		Title:    titleFilter, | ||||
| 		Tags:     tagsFilter, | ||||
| 		Priority: priorityFilter, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func (q *queryFilter) Pass(msg *message) bool { | ||||
| 	if msg.Event != messageEvent { | ||||
| 		return true // filters only apply to messages | ||||
| 	} | ||||
| 	if q.Message != "" && msg.Message != q.Message { | ||||
| 		return false | ||||
| 	} | ||||
| 	if q.Title != "" && msg.Title != q.Title { | ||||
| 		return false | ||||
| 	} | ||||
| 	messagePriority := msg.Priority | ||||
| 	if messagePriority == 0 { | ||||
| 		messagePriority = 3 // For query filters, default priority (3) is the same as "not set" (0) | ||||
| 	} | ||||
| 	if len(q.Priority) > 0 && !util.InIntList(q.Priority, messagePriority) { | ||||
| 		return false | ||||
| 	} | ||||
| 	if len(q.Tags) > 0 && !util.InStringListAll(msg.Tags, q.Tags) { | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
							
								
								
									
										55
									
								
								server/util.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								server/util.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,55 @@ | |||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"firebase.google.com/go/messaging" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	fcmMessageLimit = 4000 | ||||
| ) | ||||
| 
 | ||||
| // maybeTruncateFCMMessage performs best-effort truncation of FCM messages. | ||||
| // The docs say the limit is 4000 characters, but during testing it wasn't quite clear | ||||
| // what fields matter; so we're just capping the serialized JSON to 4000 bytes. | ||||
| func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message { | ||||
| 	s, err := json.Marshal(m) | ||||
| 	if err != nil { | ||||
| 		return m | ||||
| 	} | ||||
| 	if len(s) > fcmMessageLimit { | ||||
| 		over := len(s) - fcmMessageLimit + 16 // = len("truncated":"1",), sigh ... | ||||
| 		message, ok := m.Data["message"] | ||||
| 		if ok && len(message) > over { | ||||
| 			m.Data["truncated"] = "1" | ||||
| 			m.Data["message"] = message[:len(message)-over] | ||||
| 		} | ||||
| 	} | ||||
| 	return m | ||||
| } | ||||
| 
 | ||||
| func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool { | ||||
| 	value := strings.ToLower(readParam(r, names...)) | ||||
| 	if value == "" { | ||||
| 		return defaultValue | ||||
| 	} | ||||
| 	return value == "1" || value == "yes" || value == "true" | ||||
| } | ||||
| 
 | ||||
| func readParam(r *http.Request, names ...string) string { | ||||
| 	for _, name := range names { | ||||
| 		value := r.Header.Get(name) | ||||
| 		if value != "" { | ||||
| 			return strings.TrimSpace(value) | ||||
| 		} | ||||
| 	} | ||||
| 	for _, name := range names { | ||||
| 		value := r.URL.Query().Get(strings.ToLower(name)) | ||||
| 		if value != "" { | ||||
| 			return strings.TrimSpace(value) | ||||
| 		} | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
							
								
								
									
										66
									
								
								server/util_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								server/util_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,66 @@ | |||
| package server | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"firebase.google.com/go/messaging" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestMaybeTruncateFCMMessage(t *testing.T) { | ||||
| 	origMessage := strings.Repeat("this is a long string", 300) | ||||
| 	origFCMMessage := &messaging.Message{ | ||||
| 		Topic: "mytopic", | ||||
| 		Data: map[string]string{ | ||||
| 			"id":       "abcdefg", | ||||
| 			"time":     "1641324761", | ||||
| 			"event":    "message", | ||||
| 			"topic":    "mytopic", | ||||
| 			"priority": "0", | ||||
| 			"tags":     "", | ||||
| 			"title":    "", | ||||
| 			"message":  origMessage, | ||||
| 		}, | ||||
| 		Android: &messaging.AndroidConfig{ | ||||
| 			Priority: "high", | ||||
| 		}, | ||||
| 	} | ||||
| 	origMessageLength := len(origFCMMessage.Data["message"]) | ||||
| 	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage) | ||||
| 	require.Greater(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition | ||||
| 
 | ||||
| 	truncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage) | ||||
| 	truncatedMessageLength := len(truncatedFCMMessage.Data["message"]) | ||||
| 	serializedTruncatedFCMMessage, _ := json.Marshal(truncatedFCMMessage) | ||||
| 	require.Equal(t, fcmMessageLimit, len(serializedTruncatedFCMMessage)) | ||||
| 	require.Equal(t, "1", truncatedFCMMessage.Data["truncated"]) | ||||
| 	require.NotEqual(t, origMessageLength, truncatedMessageLength) | ||||
| } | ||||
| 
 | ||||
| func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) { | ||||
| 	origMessage := "not really a long string" | ||||
| 	origFCMMessage := &messaging.Message{ | ||||
| 		Topic: "mytopic", | ||||
| 		Data: map[string]string{ | ||||
| 			"id":       "abcdefg", | ||||
| 			"time":     "1641324761", | ||||
| 			"event":    "message", | ||||
| 			"topic":    "mytopic", | ||||
| 			"priority": "0", | ||||
| 			"tags":     "", | ||||
| 			"title":    "", | ||||
| 			"message":  origMessage, | ||||
| 		}, | ||||
| 	} | ||||
| 	origMessageLength := len(origFCMMessage.Data["message"]) | ||||
| 	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage) | ||||
| 	require.LessOrEqual(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition | ||||
| 
 | ||||
| 	notTruncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage) | ||||
| 	notTruncatedMessageLength := len(notTruncatedFCMMessage.Data["message"]) | ||||
| 	serializedNotTruncatedFCMMessage, _ := json.Marshal(notTruncatedFCMMessage) | ||||
| 	require.Equal(t, origMessageLength, notTruncatedMessageLength) | ||||
| 	require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage)) | ||||
| 	require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"]) | ||||
| } | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Philipp Heckel
						Philipp Heckel