diff --git a/server/server.go b/server/server.go index 17362b1e..5ddbcac3 100644 --- a/server/server.go +++ b/server/server.go @@ -9,7 +9,12 @@ import ( "encoding/json" "errors" "fmt" + "github.com/emersion/go-smtp" + "github.com/gorilla/websocket" + "golang.org/x/sync/errgroup" + "heckel.io/ntfy/log" "heckel.io/ntfy/user" + "heckel.io/ntfy/util" "io" "net" "net/http" @@ -25,13 +30,6 @@ import ( "sync" "time" "unicode/utf8" - - "heckel.io/ntfy/log" - - "github.com/emersion/go-smtp" - "github.com/gorilla/websocket" - "golang.org/x/sync/errgroup" - "heckel.io/ntfy/util" ) /* @@ -43,7 +41,6 @@ import ( - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade) - MEDIUM: Reservation table delete button: dialog "keep or delete messages?" -- MEDIUM: Tests for remaining payment endpoints - LOW: UI: Flickering upgrade banner when logging in - LOW: JS constants - LOW: Payments reconciliation process diff --git a/server/server_payments_test.go b/server/server_payments_test.go index d1c8de4c..c1903812 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -524,6 +524,70 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active( require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID)) } +func TestPayments_Webhook_Subscription_Deleted(t *testing.T) { + // This tests incoming webhooks from Stripe to delete a subscription. It verifies that the database is + // updated (all Stripe fields are deleted, and the tier is removed). + // + // It doesn't fully test the message/attachment deletion. That is tested above in the subscription update call. + + stripeMock := &testStripeAPI{} + defer stripeMock.AssertExpectations(t) + + c := newTestConfigWithAuthFile(t) + c.StripeSecretKey = "secret key" + c.StripeWebhookKey = "webhook key" + s := newTestServer(t, c) + s.stripe = stripeMock + + // Define how the mock should react + stripeMock. + On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key"). + Return(jsonToStripeEvent(t, subscriptionDeletedEventJSON), nil) + + // Create a user with a Stripe subscription and 3 reservations + require.Nil(t, s.userManager.CreateTier(&user.Tier{ + ID: "ti_1", + Code: "pro", + StripePriceID: "price_1234", + ReservationLimit: 1, + })) + require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) + require.Nil(t, s.userManager.ChangeTier("phil", "pro")) + require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll)) + + // Add billing details + u, err := s.userManager.User("phil") + require.Nil(t, err) + require.Nil(t, s.userManager.ChangeBilling(u.Name, &user.Billing{ + StripeCustomerID: "acct_5555", + StripeSubscriptionID: "sub_1234", + StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue, + StripeSubscriptionPaidUntil: time.Unix(123, 0), + StripeSubscriptionCancelAt: time.Unix(0, 0), + })) + + // Call the webhook: This does all the magic + rr := request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{ + "Stripe-Signature": "stripe signature", + }) + require.Equal(t, 200, rr.Code) + + // Verify that database columns were updated + u, err = s.userManager.User("phil") + require.Nil(t, err) + require.Nil(t, u.Tier) + require.Equal(t, "acct_5555", u.Billing.StripeCustomerID) + require.Equal(t, "", u.Billing.StripeSubscriptionID) + require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus) + require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix()) + require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) + + // Verify that reservations were deleted + r, err := s.userManager.Reservations("phil") + require.Nil(t, err) + require.Equal(t, 0, len(r)) +} + func TestPayments_Subscription_Update_Different_Tier(t *testing.T) { stripeMock := &testStripeAPI{} defer stripeMock.AssertExpectations(t) @@ -739,3 +803,26 @@ const subscriptionUpdatedEventJSON = ` } } }` + +const subscriptionDeletedEventJSON = ` +{ + "type": "customer.subscription.deleted", + "data": { + "object": { + "id": "sub_1234", + "customer": "acct_5555", + "status": "active", + "current_period_end": 1674268231, + "cancel_at": 1674299999, + "items": { + "data": [ + { + "price": { + "id": "price_1234" + } + } + ] + } + } + } +}` diff --git a/server/server_test.go b/server/server_test.go index 06fe4826..915b00f3 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1872,11 +1872,11 @@ func subscribe(t *testing.T, s *Server, url string, rr *httptest.ResponseRecorde done <- true }() cancelAndWaitForDone := func() { - time.Sleep(100 * time.Millisecond) + time.Sleep(200 * time.Millisecond) cancel() <-done } - time.Sleep(100 * time.Millisecond) + time.Sleep(200 * time.Millisecond) return cancelAndWaitForDone } diff --git a/util/util_test.go b/util/util_test.go index 10381f38..35b4f790 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -1,12 +1,14 @@ package util import ( + "errors" "io" "net/netip" "os" "path/filepath" "strings" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -212,3 +214,51 @@ func TestReadJSONWithLimit_NoAllowEmpty(t *testing.T) { _, err := UnmarshalJSONWithLimit[testJSON](io.NopCloser(strings.NewReader(` `)), 10, false) require.Equal(t, ErrUnmarshalJSON, err) } + +func TestRetry_Succeeds(t *testing.T) { + start := time.Now() + delays, i := []time.Duration{10 * time.Millisecond, 50 * time.Millisecond, 100 * time.Millisecond, time.Second}, 0 + fn := func() (*int, error) { + i++ + if i < len(delays) { + return nil, errors.New("error") + } + return Int(99), nil + } + result, err := Retry[int](fn, delays...) + require.Nil(t, err) + require.Equal(t, 99, *result) + require.True(t, time.Since(start).Milliseconds() > 150) +} + +func TestRetry_Fails(t *testing.T) { + fn := func() (*int, error) { + return nil, errors.New("fails") + } + _, err := Retry[int](fn, 10*time.Millisecond) + require.Error(t, err) +} + +func TestMinMax(t *testing.T) { + require.Equal(t, 10, MinMax(9, 10, 99)) + require.Equal(t, 99, MinMax(100, 10, 99)) + require.Equal(t, 50, MinMax(50, 10, 99)) +} + +func TestPointerFunctions(t *testing.T) { + i, s, ti := Int(99), String("abc"), Time(time.Unix(99, 0)) + require.Equal(t, 99, *i) + require.Equal(t, "abc", *s) + require.Equal(t, time.Unix(99, 0), *ti) +} + +func TestMaybeMarshalJSON(t *testing.T) { + require.Equal(t, `"aa"`, MaybeMarshalJSON("aa")) + require.Equal(t, `[ + "aa", + "bb" +]`, MaybeMarshalJSON([]string{"aa", "bb"})) + require.Equal(t, "", MaybeMarshalJSON(func() {})) + require.Equal(t, `"`+strings.Repeat("x", 4999), MaybeMarshalJSON(strings.Repeat("x", 6000))) + +}