diff --git a/pkg/hicli/database/event.go b/pkg/hicli/database/event.go index 5f82977..a857bcd 100644 --- a/pkg/hicli/database/event.go +++ b/pkg/hicli/database/event.go @@ -284,20 +284,6 @@ type LocalContent struct { HTMLVersion int `json:"html_version,omitempty"` } -type UnreadType int - -func (ut UnreadType) Is(flag UnreadType) bool { - return ut&flag != 0 -} - -const ( - UnreadTypeNone UnreadType = 0b0000 - UnreadTypeNormal UnreadType = 0b0001 - UnreadTypeNotify UnreadType = 0b0010 - UnreadTypeHighlight UnreadType = 0b0100 - UnreadTypeSound UnreadType = 0b1000 -) - type Event struct { RowID EventRowID `json:"rowid"` TimelineRowID TimelineRowID `json:"timeline_rowid"` diff --git a/pkg/hicli/database/room.go b/pkg/hicli/database/room.go index fe38f2f..92fcd0e 100644 --- a/pkg/hicli/database/room.go +++ b/pkg/hicli/database/room.go @@ -145,11 +145,9 @@ type Room struct { EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"` HasMemberList bool `json:"has_member_list"` - PreviewEventRowID EventRowID `json:"preview_event_rowid"` - SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"` - UnreadHighlights int `json:"unread_highlights"` - UnreadNotifications int `json:"unread_notifications"` - UnreadMessages int `json:"unread_messages"` + PreviewEventRowID EventRowID `json:"preview_event_rowid"` + SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"` + UnreadCounts PrevBatch string `json:"prev_batch"` } diff --git a/pkg/hicli/database/unread.go b/pkg/hicli/database/unread.go new file mode 100644 index 0000000..b555a05 --- /dev/null +++ b/pkg/hicli/database/unread.go @@ -0,0 +1,84 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "maunium.net/go/mautrix/id" +) + +const ( + // TODO find out if this needs to be wrapped in another query that limits the number of events it evaluates + // (or maybe the timeline store just shouldn't be allowed to grow that big?) + calculateUnreadsQuery = ` + SELECT + COALESCE(SUM(CASE WHEN unread_type & 0100 THEN 1 ELSE 0 END), 0) AS highlights, + COALESCE(SUM(CASE WHEN unread_type & 0010 THEN 1 ELSE 0 END), 0) AS notifications, + COALESCE(SUM(CASE WHEN unread_type & 0001 THEN 1 ELSE 0 END), 0) AS messages + FROM timeline + JOIN event ON event.rowid = timeline.event_rowid + WHERE timeline.room_id = $1 AND timeline.rowid > ( + SELECT MAX(rowid) + FROM timeline + WHERE room_id = $1 AND event_rowid IN ( + SELECT event.rowid + FROM receipt + JOIN event ON receipt.event_id=event.event_id + WHERE receipt.room_id = $1 AND receipt.user_id = $2 + ) + ) + ` +) + +func (rq *RoomQuery) CalculateUnreads(ctx context.Context, roomID id.RoomID, userID id.UserID) (uc UnreadCounts, err error) { + err = rq.GetDB().QueryRow(ctx, calculateUnreadsQuery, roomID, userID). + Scan(&uc.UnreadHighlights, &uc.UnreadNotifications, &uc.UnreadMessages) + return +} + +type UnreadType int + +func (ut UnreadType) Is(flag UnreadType) bool { + return ut&flag != 0 +} + +const ( + UnreadTypeNone UnreadType = 0b0000 + UnreadTypeNormal UnreadType = 0b0001 + UnreadTypeNotify UnreadType = 0b0010 + UnreadTypeHighlight UnreadType = 0b0100 + UnreadTypeSound UnreadType = 0b1000 +) + +type UnreadCounts struct { + UnreadHighlights int `json:"unread_highlights"` + UnreadNotifications int `json:"unread_notifications"` + UnreadMessages int `json:"unread_messages"` +} + +func (uc *UnreadCounts) IsZero() bool { + return uc.UnreadHighlights == 0 && uc.UnreadNotifications == 0 && uc.UnreadMessages == 0 +} + +func (uc *UnreadCounts) Add(other UnreadCounts) { + uc.UnreadHighlights += other.UnreadHighlights + uc.UnreadNotifications += other.UnreadNotifications + uc.UnreadMessages += other.UnreadMessages +} + +func (uc *UnreadCounts) AddOne(ut UnreadType) { + if ut.Is(UnreadTypeNormal) { + uc.UnreadMessages++ + } + if ut.Is(UnreadTypeNotify) { + uc.UnreadNotifications++ + } + if ut.Is(UnreadTypeHighlight) { + uc.UnreadHighlights++ + } +} diff --git a/pkg/hicli/send.go b/pkg/hicli/send.go index 2e269c7..9529259 100644 --- a/pkg/hicli/send.go +++ b/pkg/hicli/send.go @@ -133,6 +133,7 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ Reactions: map[string]int{}, LastEditRowID: ptr.Ptr(database.EventRowID(0)), } + dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, dbEvt.AsRawMautrix()) _, err = h.DB.Event.Insert(ctx, dbEvt) if err != nil { return nil, fmt.Errorf("failed to insert event into database: %w", err) diff --git a/pkg/hicli/sync.go b/pkg/hicli/sync.go index d2916cf..47471a1 100644 --- a/pkg/hicli/sync.go +++ b/pkg/hicli/sync.go @@ -156,11 +156,6 @@ func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*databa return receiptList, newOwnReceipts } -type receiptsToSave struct { - roomID id.RoomID - receipts []*database.Receipt -} - func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error { existingRoomData, err := h.DB.Room.Get(ctx, roomID) if err != nil { @@ -181,7 +176,7 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) } } - var receipts []receiptsToSave + var receiptsList []*database.Receipt var newOwnReceipts []id.EventID for _, evt := range room.Ephemeral.Events { evt.Type.Class = event.EphemeralEventType @@ -192,9 +187,9 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, } switch evt.Type { case event.EphemeralEventReceipt: - var receiptsList []*database.Receipt - receiptsList, newOwnReceipts = h.receiptsToList(evt.Content.AsReceipt()) - receipts = append(receipts, receiptsToSave{roomID, receiptsList}) + list, ownList := h.receiptsToList(evt.Content.AsReceipt()) + receiptsList = append(receiptsList, list...) + newOwnReceipts = append(newOwnReceipts, ownList...) case event.EphemeralEventTyping: go h.EventHandler(&Typing{ RoomID: roomID, @@ -202,16 +197,18 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, }) } } - err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, newOwnReceipts, room.UnreadNotifications) + err = h.processStateAndTimeline( + ctx, + existingRoomData, + &room.State, + &room.Timeline, + &room.Summary, + receiptsList, + newOwnReceipts, + ) if err != nil { return err } - for _, rs := range receipts { - err = h.DB.Receipt.PutMany(ctx, rs.roomID, rs.receipts...) - if err != nil { - return fmt.Errorf("failed to save receipts: %w", err) - } - } return nil } @@ -222,8 +219,9 @@ func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, ro } else if existingRoomData == nil { return nil } - // TODO delete room instead of processing? - return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil) + // TODO delete room + return nil + //return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil) } func isDecryptionErrorRetryable(err error) bool { @@ -333,7 +331,7 @@ func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID datab } func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent { - if evt.Type != event.EventMessage && evt.Type != event.EventSticker { + if evt.Type != event.EventMessage { return nil } _ = evt.Content.ParseRaw(evt.Type) @@ -468,21 +466,15 @@ func (h *HiClient) processStateAndTimeline( state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary, + receipts []*database.Receipt, newOwnReceipts []id.EventID, - serverNotificationCounts *mautrix.UnreadNotificationCounts, ) error { updatedRoom := &database.Room{ ID: room.ID, - SortingTimestamp: room.SortingTimestamp, - NameQuality: room.NameQuality, - UnreadHighlights: room.UnreadHighlights, - UnreadNotifications: room.UnreadNotifications, - UnreadMessages: room.UnreadMessages, - } - if serverNotificationCounts != nil { - updatedRoom.UnreadHighlights = serverNotificationCounts.HighlightCount - updatedRoom.UnreadNotifications = serverNotificationCounts.NotificationCount + SortingTimestamp: room.SortingTimestamp, + NameQuality: room.NameQuality, + UnreadCounts: room.UnreadCounts, } heroesChanged := false if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil { @@ -498,6 +490,7 @@ func (h *HiClient) processStateAndTimeline( allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events)) newNotifications := make([]SyncNotification, 0) recalculatePreviewEvent := false + var newUnreadCounts database.UnreadCounts addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) { if rowID != 0 { dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID) @@ -538,11 +531,14 @@ func (h *HiClient) processStateAndTimeline( if err != nil { return -1, err } - if isUnread && dbEvt.UnreadType.Is(database.UnreadTypeNotify) { - newNotifications = append(newNotifications, SyncNotification{ - RowID: dbEvt.RowID, - Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound), - }) + if isUnread { + if dbEvt.UnreadType.Is(database.UnreadTypeNotify) { + newNotifications = append(newNotifications, SyncNotification{ + RowID: dbEvt.RowID, + Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound), + }) + } + newUnreadCounts.AddOne(dbEvt.UnreadType) } if isTimeline { if dbEvt.CanUseForPreview() { @@ -604,6 +600,9 @@ func (h *HiClient) processStateAndTimeline( for i := len(timeline.Events) - 1; i >= 0; i-- { if slices.Contains(newOwnReceipts, timeline.Events[i].ID) { readUpToIndex = i + // Reset unread counts if we see our own read receipt in the timeline. + // It'll be updated with new unreads (if any) at the end. + updatedRoom.UnreadCounts = database.UnreadCounts{} break } } @@ -671,6 +670,21 @@ func (h *HiClient) processStateAndTimeline( updatedRoom.Avatar = &dmAvatarURL } } + + if len(receipts) > 0 { + err = h.DB.Receipt.PutMany(ctx, room.ID, receipts...) + if err != nil { + return fmt.Errorf("failed to save receipts: %w", err) + } + } + if len(newOwnReceipts) > 0 && newUnreadCounts.IsZero() { + updatedRoom.UnreadCounts, err = h.DB.Room.CalculateUnreads(ctx, room.ID, h.Account.UserID) + if err != nil { + return fmt.Errorf("failed to recalculate unread counts: %w", err) + } + } else { + updatedRoom.UnreadCounts.Add(newUnreadCounts) + } if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) { updatedRoom.PrevBatch = timeline.PrevBatch }