hicli: calculate unreads locally

This commit is contained in:
Tulir Asokan 2024-10-17 21:49:57 +03:00
parent 504e2bd976
commit 0455ff3d24
5 changed files with 136 additions and 53 deletions

View file

@ -284,20 +284,6 @@ type LocalContent struct {
HTMLVersion int `json:"html_version,omitempty"`
}
type UnreadType int
func (ut UnreadType) Is(flag UnreadType) bool {
return ut&flag != 0
}
const (
UnreadTypeNone UnreadType = 0b0000
UnreadTypeNormal UnreadType = 0b0001
UnreadTypeNotify UnreadType = 0b0010
UnreadTypeHighlight UnreadType = 0b0100
UnreadTypeSound UnreadType = 0b1000
)
type Event struct {
RowID EventRowID `json:"rowid"`
TimelineRowID TimelineRowID `json:"timeline_rowid"`

View file

@ -145,11 +145,9 @@ type Room struct {
EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"`
HasMemberList bool `json:"has_member_list"`
PreviewEventRowID EventRowID `json:"preview_event_rowid"`
SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"`
UnreadHighlights int `json:"unread_highlights"`
UnreadNotifications int `json:"unread_notifications"`
UnreadMessages int `json:"unread_messages"`
PreviewEventRowID EventRowID `json:"preview_event_rowid"`
SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"`
UnreadCounts
PrevBatch string `json:"prev_batch"`
}

View file

@ -0,0 +1,84 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"maunium.net/go/mautrix/id"
)
const (
// TODO find out if this needs to be wrapped in another query that limits the number of events it evaluates
// (or maybe the timeline store just shouldn't be allowed to grow that big?)
calculateUnreadsQuery = `
SELECT
COALESCE(SUM(CASE WHEN unread_type & 0100 THEN 1 ELSE 0 END), 0) AS highlights,
COALESCE(SUM(CASE WHEN unread_type & 0010 THEN 1 ELSE 0 END), 0) AS notifications,
COALESCE(SUM(CASE WHEN unread_type & 0001 THEN 1 ELSE 0 END), 0) AS messages
FROM timeline
JOIN event ON event.rowid = timeline.event_rowid
WHERE timeline.room_id = $1 AND timeline.rowid > (
SELECT MAX(rowid)
FROM timeline
WHERE room_id = $1 AND event_rowid IN (
SELECT event.rowid
FROM receipt
JOIN event ON receipt.event_id=event.event_id
WHERE receipt.room_id = $1 AND receipt.user_id = $2
)
)
`
)
func (rq *RoomQuery) CalculateUnreads(ctx context.Context, roomID id.RoomID, userID id.UserID) (uc UnreadCounts, err error) {
err = rq.GetDB().QueryRow(ctx, calculateUnreadsQuery, roomID, userID).
Scan(&uc.UnreadHighlights, &uc.UnreadNotifications, &uc.UnreadMessages)
return
}
type UnreadType int
func (ut UnreadType) Is(flag UnreadType) bool {
return ut&flag != 0
}
const (
UnreadTypeNone UnreadType = 0b0000
UnreadTypeNormal UnreadType = 0b0001
UnreadTypeNotify UnreadType = 0b0010
UnreadTypeHighlight UnreadType = 0b0100
UnreadTypeSound UnreadType = 0b1000
)
type UnreadCounts struct {
UnreadHighlights int `json:"unread_highlights"`
UnreadNotifications int `json:"unread_notifications"`
UnreadMessages int `json:"unread_messages"`
}
func (uc *UnreadCounts) IsZero() bool {
return uc.UnreadHighlights == 0 && uc.UnreadNotifications == 0 && uc.UnreadMessages == 0
}
func (uc *UnreadCounts) Add(other UnreadCounts) {
uc.UnreadHighlights += other.UnreadHighlights
uc.UnreadNotifications += other.UnreadNotifications
uc.UnreadMessages += other.UnreadMessages
}
func (uc *UnreadCounts) AddOne(ut UnreadType) {
if ut.Is(UnreadTypeNormal) {
uc.UnreadMessages++
}
if ut.Is(UnreadTypeNotify) {
uc.UnreadNotifications++
}
if ut.Is(UnreadTypeHighlight) {
uc.UnreadHighlights++
}
}

View file

@ -133,6 +133,7 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ
Reactions: map[string]int{},
LastEditRowID: ptr.Ptr(database.EventRowID(0)),
}
dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, dbEvt.AsRawMautrix())
_, err = h.DB.Event.Insert(ctx, dbEvt)
if err != nil {
return nil, fmt.Errorf("failed to insert event into database: %w", err)

View file

@ -156,11 +156,6 @@ func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*databa
return receiptList, newOwnReceipts
}
type receiptsToSave struct {
roomID id.RoomID
receipts []*database.Receipt
}
func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
@ -181,7 +176,7 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID,
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
}
}
var receipts []receiptsToSave
var receiptsList []*database.Receipt
var newOwnReceipts []id.EventID
for _, evt := range room.Ephemeral.Events {
evt.Type.Class = event.EphemeralEventType
@ -192,9 +187,9 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID,
}
switch evt.Type {
case event.EphemeralEventReceipt:
var receiptsList []*database.Receipt
receiptsList, newOwnReceipts = h.receiptsToList(evt.Content.AsReceipt())
receipts = append(receipts, receiptsToSave{roomID, receiptsList})
list, ownList := h.receiptsToList(evt.Content.AsReceipt())
receiptsList = append(receiptsList, list...)
newOwnReceipts = append(newOwnReceipts, ownList...)
case event.EphemeralEventTyping:
go h.EventHandler(&Typing{
RoomID: roomID,
@ -202,16 +197,18 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID,
})
}
}
err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, newOwnReceipts, room.UnreadNotifications)
err = h.processStateAndTimeline(
ctx,
existingRoomData,
&room.State,
&room.Timeline,
&room.Summary,
receiptsList,
newOwnReceipts,
)
if err != nil {
return err
}
for _, rs := range receipts {
err = h.DB.Receipt.PutMany(ctx, rs.roomID, rs.receipts...)
if err != nil {
return fmt.Errorf("failed to save receipts: %w", err)
}
}
return nil
}
@ -222,8 +219,9 @@ func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, ro
} else if existingRoomData == nil {
return nil
}
// TODO delete room instead of processing?
return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil)
// TODO delete room
return nil
//return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil)
}
func isDecryptionErrorRetryable(err error) bool {
@ -333,7 +331,7 @@ func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID datab
}
func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent {
if evt.Type != event.EventMessage && evt.Type != event.EventSticker {
if evt.Type != event.EventMessage {
return nil
}
_ = evt.Content.ParseRaw(evt.Type)
@ -468,21 +466,15 @@ func (h *HiClient) processStateAndTimeline(
state *mautrix.SyncEventsList,
timeline *mautrix.SyncTimeline,
summary *mautrix.LazyLoadSummary,
receipts []*database.Receipt,
newOwnReceipts []id.EventID,
serverNotificationCounts *mautrix.UnreadNotificationCounts,
) error {
updatedRoom := &database.Room{
ID: room.ID,
SortingTimestamp: room.SortingTimestamp,
NameQuality: room.NameQuality,
UnreadHighlights: room.UnreadHighlights,
UnreadNotifications: room.UnreadNotifications,
UnreadMessages: room.UnreadMessages,
}
if serverNotificationCounts != nil {
updatedRoom.UnreadHighlights = serverNotificationCounts.HighlightCount
updatedRoom.UnreadNotifications = serverNotificationCounts.NotificationCount
SortingTimestamp: room.SortingTimestamp,
NameQuality: room.NameQuality,
UnreadCounts: room.UnreadCounts,
}
heroesChanged := false
if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil {
@ -498,6 +490,7 @@ func (h *HiClient) processStateAndTimeline(
allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events))
newNotifications := make([]SyncNotification, 0)
recalculatePreviewEvent := false
var newUnreadCounts database.UnreadCounts
addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) {
if rowID != 0 {
dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID)
@ -538,11 +531,14 @@ func (h *HiClient) processStateAndTimeline(
if err != nil {
return -1, err
}
if isUnread && dbEvt.UnreadType.Is(database.UnreadTypeNotify) {
newNotifications = append(newNotifications, SyncNotification{
RowID: dbEvt.RowID,
Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound),
})
if isUnread {
if dbEvt.UnreadType.Is(database.UnreadTypeNotify) {
newNotifications = append(newNotifications, SyncNotification{
RowID: dbEvt.RowID,
Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound),
})
}
newUnreadCounts.AddOne(dbEvt.UnreadType)
}
if isTimeline {
if dbEvt.CanUseForPreview() {
@ -604,6 +600,9 @@ func (h *HiClient) processStateAndTimeline(
for i := len(timeline.Events) - 1; i >= 0; i-- {
if slices.Contains(newOwnReceipts, timeline.Events[i].ID) {
readUpToIndex = i
// Reset unread counts if we see our own read receipt in the timeline.
// It'll be updated with new unreads (if any) at the end.
updatedRoom.UnreadCounts = database.UnreadCounts{}
break
}
}
@ -671,6 +670,21 @@ func (h *HiClient) processStateAndTimeline(
updatedRoom.Avatar = &dmAvatarURL
}
}
if len(receipts) > 0 {
err = h.DB.Receipt.PutMany(ctx, room.ID, receipts...)
if err != nil {
return fmt.Errorf("failed to save receipts: %w", err)
}
}
if len(newOwnReceipts) > 0 && newUnreadCounts.IsZero() {
updatedRoom.UnreadCounts, err = h.DB.Room.CalculateUnreads(ctx, room.ID, h.Account.UserID)
if err != nil {
return fmt.Errorf("failed to recalculate unread counts: %w", err)
}
} else {
updatedRoom.UnreadCounts.Add(newUnreadCounts)
}
if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) {
updatedRoom.PrevBatch = timeline.PrevBatch
}