From 2179fb2c1826ae650fe003ea6357b40704d4d3e7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 18 Oct 2024 00:57:27 +0300 Subject: [PATCH] hicli/sync: recalculate unreads on redaction --- pkg/hicli/database/event.go | 6 +++--- pkg/hicli/database/unread.go | 2 +- pkg/hicli/sync.go | 7 +++++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pkg/hicli/database/event.go b/pkg/hicli/database/event.go index a857bcd..c7147b9 100644 --- a/pkg/hicli/database/event.go +++ b/pkg/hicli/database/event.go @@ -467,15 +467,15 @@ func (e *Event) sqlVariables() []any { } func (e *Event) GetNonPushUnreadType() UnreadType { - if e.RelationType == event.RelReplace { + if e.RelationType == event.RelReplace || e.RedactedBy != "" { return UnreadTypeNone } switch e.Type { - case event.EventMessage.Type, event.EventSticker.Type: + case event.EventMessage.Type, event.EventSticker.Type, event.EventUnstablePollStart.Type: return UnreadTypeNormal case event.EventEncrypted.Type: switch e.DecryptedType { - case event.EventMessage.Type, event.EventSticker.Type: + case event.EventMessage.Type, event.EventSticker.Type, event.EventUnstablePollStart.Type: return UnreadTypeNormal } } diff --git a/pkg/hicli/database/unread.go b/pkg/hicli/database/unread.go index b555a05..3faf078 100644 --- a/pkg/hicli/database/unread.go +++ b/pkg/hicli/database/unread.go @@ -31,7 +31,7 @@ const ( JOIN event ON receipt.event_id=event.event_id WHERE receipt.room_id = $1 AND receipt.user_id = $2 ) - ) + ) AND unread_type > 0 AND redacted_by IS NULL ` ) diff --git a/pkg/hicli/sync.go b/pkg/hicli/sync.go index 54379c1..2768dc0 100644 --- a/pkg/hicli/sync.go +++ b/pkg/hicli/sync.go @@ -489,7 +489,7 @@ func (h *HiClient) processStateAndTimeline( decryptionQueue := make(map[id.SessionID]*database.SessionRequest) allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events)) newNotifications := make([]SyncNotification, 0) - recalculatePreviewEvent := false + var recalculatePreviewEvent, unreadMessagesWereMaybeRedacted bool var newUnreadCounts database.UnreadCounts addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) { if rowID != 0 { @@ -513,6 +513,9 @@ func (h *HiClient) processStateAndTimeline( if dbEvt == nil { return nil } + if dbEvt.UnreadType > 0 { + unreadMessagesWereMaybeRedacted = true + } if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation { _, err = addOldEvent(0, dbEvt.RelatesTo) if err != nil { @@ -677,7 +680,7 @@ func (h *HiClient) processStateAndTimeline( return fmt.Errorf("failed to save receipts: %w", err) } } - if len(newOwnReceipts) > 0 && newUnreadCounts.IsZero() { + if !room.UnreadCounts.IsZero() && ((len(newOwnReceipts) > 0 && newUnreadCounts.IsZero()) || unreadMessagesWereMaybeRedacted) { updatedRoom.UnreadCounts, err = h.DB.Room.CalculateUnreads(ctx, room.ID, h.Account.UserID) if err != nil { return fmt.Errorf("failed to recalculate unread counts: %w", err)