gomuks/pkg/hicli/sync.go
Tulir Asokan f5eeb8461a hicli/sync: enable reaction count collection for all events
The reaction aggregation probably needs to be redone to support finding
the entire event (to see senders and other content like shortcodes),
but this is good enough to get reactions rendering.
2024-10-21 23:11:14 +03:00

914 lines
30 KiB
Go

// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"errors"
"fmt"
"slices"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/exzerolog"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
"go.mau.fi/gomuks/pkg/hicli/database"
)
// syncContext is the per-sync state stored in the context under syncContextKey while a
// sync response is being processed.
type syncContext struct {
// shouldWakeupRequestQueue is set by sync processing (e.g. after marking device lists
// outdated or queueing session requests); postProcessSyncResponse then wakes up the
// request queue.
shouldWakeupRequestQueue bool
// evt accumulates the per-room changes of this sync; postProcessSyncResponse passes it
// to the EventHandler unless it is empty.
evt *SyncComplete
}
// preProcessSyncResponse handles the to-device events that have to be processed before
// the rest of the sync response. Encrypted events and room key withheld notices are
// handed to the crypto machine immediately. Every other parsed to-device event is kept
// in resp.ToDevice.Events for later handling. Events whose content fails to parse are
// logged and dropped.
func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
	log := zerolog.Ctx(ctx)
	// Filter in place: the kept events reuse the backing array of the original slice.
	kept := resp.ToDevice.Events[:0]
	for _, toDeviceEvt := range resp.ToDevice.Events {
		toDeviceEvt.Type.Class = event.ToDeviceEventType
		parseErr := toDeviceEvt.Content.ParseRaw(toDeviceEvt.Type)
		if parseErr != nil && !errors.Is(parseErr, event.ErrContentAlreadyParsed) {
			log.Warn().Err(parseErr).
				Stringer("event_type", &toDeviceEvt.Type).
				Stringer("sender", toDeviceEvt.Sender).
				Msg("Failed to parse to-device event, skipping")
			continue
		}
		if _, isEncrypted := toDeviceEvt.Content.Parsed.(*event.EncryptedEventContent); isEncrypted {
			h.Crypto.HandleEncryptedEvent(ctx, toDeviceEvt)
		} else if withheld, isWithheld := toDeviceEvt.Content.Parsed.(*event.RoomKeyWithheldEventContent); isWithheld {
			h.Crypto.HandleRoomKeyWithheld(ctx, withheld)
		} else {
			kept = append(kept, toDeviceEvt)
		}
	}
	resp.ToDevice.Events = kept
	return nil
}
// postProcessSyncResponse runs after a sync response has been processed. In order, it:
//   - passes the one-time key counts to the crypto machine,
//   - starts asyncPostProcessSyncResponse in a goroutine,
//   - wakes the request queue if sync processing asked for it,
//   - marks the first sync as received,
//   - dispatches the accumulated SyncComplete event unless it is empty.
func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
	h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
	go h.asyncPostProcessSyncResponse(ctx, resp, since)
	sc := ctx.Value(syncContextKey).(*syncContext)
	if sc.shouldWakeupRequestQueue {
		h.WakeupRequestQueue()
	}
	h.firstSyncReceived = true
	if sc.evt.IsEmpty() {
		return
	}
	h.EventHandler(sc.evt)
}
// asyncPostProcessSyncResponse hands secret requests and room key requests found in the
// remaining to-device events of the sync response to the crypto machine.
func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
	for _, toDeviceEvt := range resp.ToDevice.Events {
		if secretReq, ok := toDeviceEvt.Content.Parsed.(*event.SecretRequestEventContent); ok {
			h.Crypto.HandleSecretRequest(ctx, toDeviceEvt.Sender, secretReq)
		} else if keyReq, ok := toDeviceEvt.Content.Parsed.(*event.RoomKeyRequestEventContent); ok {
			h.Crypto.HandleRoomKeyRequest(ctx, toDeviceEvt.Sender, keyReq)
		}
	}
}
// processSyncResponse persists the contents of a sync response. In order, it:
//   - marks changed device lists as outdated,
//   - stores global account data (applying updated push rules),
//   - processes joined and left rooms,
//   - saves the new next_batch token.
//
// The first error encountered aborts processing and is returned.
func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
	log := zerolog.Ctx(ctx)
	if changed := resp.DeviceLists.Changed; len(changed) > 0 {
		log.Debug().
			Array("users", exzerolog.ArrayOfStringers(changed)).
			Msg("Marking changed device lists for tracked users as outdated")
		if err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, changed); err != nil {
			return fmt.Errorf("failed to mark changed device lists as outdated: %w", err)
		}
		// Presumably so the request queue picks up the outdated device lists after this sync.
		ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
	}
	for _, accountDataEvt := range resp.AccountData.Events {
		accountDataEvt.Type.Class = event.AccountDataEventType
		if err := h.DB.AccountData.Put(ctx, h.Account.UserID, accountDataEvt.Type, accountDataEvt.Content.VeryRaw); err != nil {
			return fmt.Errorf("failed to save account data event %s: %w", accountDataEvt.Type.Type, err)
		}
		if accountDataEvt.Type != event.AccountDataPushRules {
			continue
		}
		parseErr := accountDataEvt.Content.ParseRaw(accountDataEvt.Type)
		if parseErr != nil && !errors.Is(parseErr, event.ErrContentAlreadyParsed) {
			log.Warn().Err(parseErr).Msg("Failed to parse push rules in sync")
		} else if pushRules, ok := accountDataEvt.Content.Parsed.(*pushrules.EventContent); ok {
			h.receiveNewPushRules(ctx, pushRules.Ruleset)
			log.Debug().Msg("Updated push rules from sync")
		}
	}
	for roomID, joinedRoom := range resp.Rooms.Join {
		if err := h.processSyncJoinedRoom(ctx, roomID, joinedRoom); err != nil {
			return fmt.Errorf("failed to process joined room %s: %w", roomID, err)
		}
	}
	for roomID, leftRoom := range resp.Rooms.Leave {
		if err := h.processSyncLeftRoom(ctx, roomID, leftRoom); err != nil {
			return fmt.Errorf("failed to process left room %s: %w", roomID, err)
		}
	}
	h.Account.NextBatch = resp.NextBatch
	if err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch); err != nil {
		return fmt.Errorf("failed to save next_batch: %w", err)
	}
	return nil
}
// receiptsToList flattens receipt event content (event ID -> receipt type -> user ID)
// into database receipt rows. It also returns the event IDs that the current user's own
// receipts point at. There is one entry per receipt, so the same ID may appear more
// than once.
func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) {
	rows := make([]*database.Receipt, 0)
	var ownEventIDs []id.EventID
	for evtID, byType := range *content {
		for receiptType, byUser := range byType {
			for userID, info := range byUser {
				rows = append(rows, &database.Receipt{
					UserID:      userID,
					ReceiptType: receiptType,
					ThreadID:    info.ThreadID,
					EventID:     evtID,
					Timestamp:   jsontime.UM(info.Timestamp),
				})
				if userID == h.Account.UserID {
					ownEventIDs = append(ownEventIDs, evtID)
				}
			}
		}
	}
	return rows, ownEventIDs
}
// processSyncJoinedRoom handles one joined room from a sync response. It:
//   - makes sure the room row exists,
//   - stores room account data,
//   - collects read receipts and dispatches typing notifications right away,
//   - hands state, timeline and summary over to processStateAndTimeline.
func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
	existingRoomData, err := h.DB.Room.Get(ctx, roomID)
	if err != nil {
		return fmt.Errorf("failed to get room data: %w", err)
	}
	if existingRoomData == nil {
		if err = h.DB.Room.CreateRow(ctx, roomID); err != nil {
			return fmt.Errorf("failed to ensure room row exists: %w", err)
		}
		existingRoomData = &database.Room{
			ID: roomID,
			// Hack to set a default value for SortingTimestamp which is before all existing rooms,
			// but not the same for all rooms without a timestamp.
			// Unix seconds are deliberately interpreted as milliseconds here.
			SortingTimestamp: jsontime.UM(time.UnixMilli(time.Now().Unix())),
		}
	}
	for _, accountDataEvt := range room.AccountData.Events {
		accountDataEvt.Type.Class = event.AccountDataEventType
		accountDataEvt.RoomID = roomID
		if err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, accountDataEvt.Type, accountDataEvt.Content.VeryRaw); err != nil {
			return fmt.Errorf("failed to save account data event %s: %w", accountDataEvt.Type.Type, err)
		}
	}
	var receiptsList []*database.Receipt
	var newOwnReceipts []id.EventID
	for _, ephemeralEvt := range room.Ephemeral.Events {
		ephemeralEvt.Type.Class = event.EphemeralEventType
		err = ephemeralEvt.Content.ParseRaw(ephemeralEvt.Type)
		if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
			zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content")
			continue
		}
		switch ephemeralEvt.Type {
		case event.EphemeralEventReceipt:
			rows, own := h.receiptsToList(ephemeralEvt.Content.AsReceipt())
			receiptsList = append(receiptsList, rows...)
			newOwnReceipts = append(newOwnReceipts, own...)
		case event.EphemeralEventTyping:
			typing := &Typing{
				RoomID:             roomID,
				TypingEventContent: *ephemeralEvt.Content.AsTyping(),
			}
			go h.EventHandler(typing)
		}
	}
	return h.processStateAndTimeline(
		ctx,
		existingRoomData,
		&room.State,
		&room.Timeline,
		&room.Summary,
		receiptsList,
		newOwnReceipts,
	)
}
// processSyncLeftRoom handles a room the user has left. Deleting the room's data is not
// implemented yet, so for now this only looks the room up and reports database errors.
// The sync payload in room is currently unused.
func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error {
	if _, err := h.DB.Room.Get(ctx, roomID); err != nil {
		return fmt.Errorf("failed to get room data: %w", err)
	}
	// TODO delete room (rooms that aren't in the database can be skipped)
	//return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil)
	return nil
}
// isDecryptionErrorRetryable reports whether a decryption failure may succeed later once
// the Megolm session is obtained. That is the case when the session wasn't found, the
// message index is unknown, or the session was withheld.
func isDecryptionErrorRetryable(err error) bool {
	for _, retryable := range []error{crypto.NoSessionFound, olm.UnknownMessageIndex, crypto.ErrGroupSessionWithheld} {
		if errors.Is(err, retryable) {
			return true
		}
	}
	return false
}
// removeReplyFallback strips the reply fallback from message and sticker events that
// reply to another event. It returns a copy of the raw content with updated
// formatted_body and body fields.
//
// It returns nil when any of these hold:
//   - the event isn't a message or sticker,
//   - the event isn't a reply,
//   - removing the fallback didn't change formatted_body,
//   - the JSON couldn't be updated.
//
// The parsed content of evt is modified in place.
func removeReplyFallback(evt *event.Event) []byte {
	switch evt.Type {
	case event.EventMessage, event.EventSticker:
		// eligible, continue below
	default:
		return nil
	}
	_ = evt.Content.ParseRaw(evt.Type)
	content, ok := evt.Content.Parsed.(*event.MessageEventContent)
	if !ok || content.RelatesTo.GetReplyTo() == "" {
		return nil
	}
	oldFormattedBody := content.FormattedBody
	content.RemoveReplyFallback()
	if content.FormattedBody == oldFormattedBody {
		return nil
	}
	updated, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody)
	if err != nil {
		return nil
	}
	updated, err = sjson.SetBytes(updated, "body", content.Body)
	if err != nil {
		return nil
	}
	return updated
}
// decryptEvent decrypts a Megolm-encrypted room event. On success it returns:
//   - the decrypted event,
//   - its raw content, with the reply fallback removed when applicable,
//   - the decrypted event type.
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) {
	if err := evt.Content.ParseRaw(evt.Type); err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
		return nil, nil, "", err
	}
	decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
	if err != nil {
		return nil, nil, "", err
	}
	rawContent := removeReplyFallback(decrypted)
	if rawContent == nil {
		rawContent = decrypted.Content.VeryRaw
	}
	return decrypted, rawContent, decrypted.Type.Type, nil
}
// addMediaCache stores a media cache entry for uri and links it to the event with row
// ID eventRowID. The entry carries the encryption info, MIME type and file name when
// available. URIs that don't parse into a valid mxc URI are ignored. Database failures
// are logged, not returned.
func (h *HiClient) addMediaCache(
	ctx context.Context,
	eventRowID database.EventRowID,
	uri id.ContentURIString,
	file *event.EncryptedFileInfo,
	info *event.FileInfo,
	fileName string,
) {
	mxc := uri.ParseOrIgnore()
	if !mxc.IsValid() {
		return
	}
	media := &database.Media{MXC: mxc, FileName: fileName}
	if file != nil {
		media.EncFile = &file.EncryptedFile
	}
	if info != nil {
		media.MimeType = info.MimeType
	}
	logFailure := func(err error, msg string) {
		zerolog.Ctx(ctx).Warn().Err(err).
			Stringer("mxc", mxc).
			Int64("event_rowid", int64(eventRowID)).
			Msg(msg)
	}
	if err := h.DB.Media.Put(ctx, media); err != nil {
		logFailure(err, "Failed to add database media entry")
		return
	}
	if err := h.DB.Media.AddReference(ctx, eventRowID, mxc); err != nil {
		logFailure(err, "Failed to add database media reference")
	}
}
// cacheMedia registers the media referenced by an event in the media cache, linked to
// the event's row ID. Covered media:
//   - message and sticker files/URLs and their thumbnails,
//   - room avatars,
//   - member avatars.
//
// Other event types are ignored.
//
// Every branch parses the content itself, so the function doesn't rely on callers
// having parsed it. A parse failure (or an unexpected content type) just makes the type
// assertion fail, and nothing is cached.
func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) {
	switch evt.Type {
	case event.EventMessage, event.EventSticker:
		_ = evt.Content.ParseRaw(evt.Type)
		content, ok := evt.Content.Parsed.(*event.MessageEventContent)
		if !ok {
			return
		}
		if content.File != nil {
			h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName())
		} else if content.URL != "" {
			h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName())
		}
		// NOTE(review): GetInfo presumably initializes content.Info when nil, which is
		// what makes the direct content.Info accesses below safe.
		if content.GetInfo().ThumbnailFile != nil {
			h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "")
		} else if content.GetInfo().ThumbnailURL != "" {
			h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "")
		}
	case event.StateRoomAvatar:
		_ = evt.Content.ParseRaw(evt.Type)
		content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
		if !ok {
			return
		}
		h.addMediaCache(ctx, rowID, content.URL, nil, nil, "")
	case event.StateMember:
		_ = evt.Content.ParseRaw(evt.Type)
		content, ok := evt.Content.Parsed.(*event.MemberEventContent)
		if !ok {
			return
		}
		h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "")
	}
}
// calculateLocalContent renders the HTML shown for an m.room.message event:
//   - HTML messages with a non-empty formatted_body: sanitized and linkified formatted_body,
//   - all other messages: linkified plain body.
//
// For edits (RelReplace with new content), the new content is rendered instead. It
// returns nil for other event types or content that can't be parsed as a message.
//
// Inline images found in the HTML are added to the media cache right away when the
// event already has a row ID. Otherwise they're returned so the caller can add them
// after saving the event.
func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) (*database.LocalContent, []id.ContentURI) {
	if evt.Type != event.EventMessage {
		return nil, nil
	}
	_ = evt.Content.ParseRaw(evt.Type)
	content, ok := evt.Content.Parsed.(*event.MessageEventContent)
	if !ok {
		return nil, nil
	}
	if dbEvt.RelationType == event.RelReplace && content.NewContent != nil {
		content = content.NewContent
	}
	if content == nil {
		return nil, nil
	}
	if content.Format != event.FormatHTML || content.FormattedBody == "" {
		var builder strings.Builder
		linkifyAndWriteBytes(&builder, []byte(content.Body))
		return &database.LocalContent{
			SanitizedHTML: builder.String(),
			HTMLVersion:   CurrentHTMLSanitizerVersion,
			WasPlaintext:  true,
		}, nil
	}
	sanitizedHTML, inlineImages, err := sanitizeAndLinkifyHTML(content.FormattedBody)
	if err != nil {
		zerolog.Ctx(ctx).Warn().Err(err).
			Stringer("event_id", dbEvt.ID).
			Msg("Failed to sanitize HTML")
	}
	if len(inlineImages) > 0 && dbEvt.RowID != 0 {
		for _, uri := range inlineImages {
			h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
		}
		// Already cached, nothing left for the caller to do.
		inlineImages = nil
	}
	return &database.LocalContent{
		SanitizedHTML: sanitizedHTML,
		HTMLVersion:   CurrentHTMLSanitizerVersion,
		WasPlaintext:  false,
	}, inlineImages
}
// CurrentHTMLSanitizerVersion is written to LocalContent.HTMLVersion. Stored events with
// an older version are re-rendered by ReprocessExistingEvent.
const CurrentHTMLSanitizerVersion = 3

// ReprocessExistingEvent re-renders the local HTML of a stored m.room.message event if
// it was produced by an older sanitizer version, then saves the result. Events without
// local content are left untouched. A failure to save is logged.
func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Event) {
	needsUpdate := evt.Type == event.EventMessage.Type &&
		evt.LocalContent != nil &&
		evt.LocalContent.HTMLVersion < CurrentHTMLSanitizerVersion
	if !needsUpdate {
		return
	}
	evt.LocalContent, _ = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix())
	if err := h.DB.Event.UpdateLocalContent(ctx, evt); err != nil {
		zerolog.Ctx(ctx).Err(err).
			Stringer("event_id", evt.ID).
			Msg("Failed to update local content")
	}
}
// postDecryptProcess does the processing that needs the usable (decrypted, if
// applicable) event. It:
//   - caches media when the event already has a row ID,
//   - evaluates push rules for events sent by other users,
//   - calculates the local HTML content.
//
// It returns the inline images that calculateLocalContent did not add to the media
// cache itself.
func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) (inlineImages []id.ContentURI) {
	if dbEvt.RowID != 0 {
		h.cacheMedia(ctx, evt, dbEvt.RowID)
	}
	if evt.Sender != h.Account.UserID {
		dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt)
	}
	localContent, images := h.calculateLocalContent(ctx, dbEvt, evt)
	dbEvt.LocalContent = localContent
	return images
}
// processEvent converts a Matrix event into its database form, decrypting it if needed,
// and saves it.
//
// If checkDB is true and an event with the same ID is already stored, the stored event
// is returned without any further processing.
//
// When decryption fails with a retryable error, a session request is prepared.
// Requests for the same session are merged so the lowest message index wins:
//   - if decryptionQueue is non-nil, the request goes into it (keyed by session ID) for
//     the caller to persist;
//   - otherwise it is saved right away and the request queue is woken up.
//
// Saving such a request is best-effort: failures are logged and do not fail event
// processing, since the event itself has already been stored.
func (h *HiClient) processEvent(
	ctx context.Context,
	evt *event.Event,
	llSummary *mautrix.LazyLoadSummary,
	decryptionQueue map[id.SessionID]*database.SessionRequest,
	checkDB bool,
) (*database.Event, error) {
	if checkDB {
		dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID)
		if err != nil {
			return nil, fmt.Errorf("failed to check if event %s exists: %w", evt.ID, err)
		} else if dbEvt != nil {
			return dbEvt, nil
		}
	}
	dbEvt := database.MautrixToEvent(evt)
	dbEvt.Reactions = make(map[string]int)
	contentWithoutFallback := removeReplyFallback(evt)
	if contentWithoutFallback != nil {
		dbEvt.Content = contentWithoutFallback
	}
	var decryptionErr error
	var decryptedMautrixEvt *event.Event
	if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
		decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
		if decryptionErr != nil {
			dbEvt.DecryptionError = decryptionErr.Error()
		}
	} else if evt.Type == event.EventRedaction {
		// Keep the top-level redacts field and the one inside the content in sync, so
		// the target can be read from either place.
		if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() {
			var err error
			evt.Content.VeryRaw, err = sjson.SetBytes(evt.Content.VeryRaw, "redacts", evt.Redacts)
			if err != nil {
				return dbEvt, fmt.Errorf("failed to set redacts field: %w", err)
			}
			// sjson returns a new buffer, and dbEvt captured the old content in
			// MautrixToEvent above. Replace it, or the patched content is never stored.
			dbEvt.Content = evt.Content.VeryRaw
		} else if evt.Redacts == "" {
			evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str)
		}
	}
	// Use the decrypted event for content-based processing whenever decryption succeeded.
	contentEvt := evt
	if decryptedMautrixEvt != nil {
		contentEvt = decryptedMautrixEvt
	}
	inlineImages := h.postDecryptProcess(ctx, llSummary, dbEvt, contentEvt)
	_, err := h.DB.Event.Upsert(ctx, dbEvt)
	if err != nil {
		return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
	}
	h.cacheMedia(ctx, contentEvt, dbEvt.RowID)
	for _, uri := range inlineImages {
		h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
	}
	if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
		// The parse error is ignored, as before; the index is only a hint for the request.
		minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext)
		req, ok := decryptionQueue[dbEvt.MegolmSessionID]
		if !ok {
			req = &database.SessionRequest{
				RoomID:    evt.RoomID,
				SessionID: dbEvt.MegolmSessionID,
				Sender:    evt.Sender,
				// Seed with this event's index. Taking min() against the zero value
				// would always give 0 and lose the actual minimum index.
				MinIndex: uint32(minIndex),
			}
		} else {
			req.MinIndex = min(uint32(minIndex), req.MinIndex)
		}
		if decryptionQueue != nil {
			decryptionQueue[dbEvt.MegolmSessionID] = req
		} else if err = h.DB.SessionRequest.Put(ctx, req); err != nil {
			// Best effort: the event is already saved, so only log the failure.
			zerolog.Ctx(ctx).Err(err).
				Stringer("session_id", dbEvt.MegolmSessionID).
				Msg("Failed to save session request")
		} else {
			h.WakeupRequestQueue()
		}
	}
	return dbEvt, nil
}
// unsetSortingTimestamp is the cutoff (2001-09-09) below which a room's SortingTimestamp
// is treated as a placeholder rather than a real event timestamp.
var unsetSortingTimestamp = time.UnixMilli(1000000000000)

// processStateAndTimeline saves the state and timeline events of one room from a sync
// response. It updates the room metadata:
//   - preview event and sorting timestamp,
//   - name and DM avatar,
//   - unread counts,
//   - lazy-loading summary,
//   - prev_batch.
//
// It then records the changes in the sync context's SyncComplete event.
//
// receipts are the read receipts received for the room in the same sync.
// newOwnReceipts are the event IDs that the current user's own receipts point at.
func (h *HiClient) processStateAndTimeline(
	ctx context.Context,
	room *database.Room,
	state *mautrix.SyncEventsList,
	timeline *mautrix.SyncTimeline,
	summary *mautrix.LazyLoadSummary,
	receipts []*database.Receipt,
	newOwnReceipts []id.EventID,
) error {
	updatedRoom := &database.Room{
		ID:               room.ID,
		SortingTimestamp: room.SortingTimestamp,
		NameQuality:      room.NameQuality,
		UnreadCounts:     room.UnreadCounts,
	}
	heroesChanged := false
	// A completely empty summary means nothing was sent; keep using the stored one.
	if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil {
		summary = room.LazyLoadSummary
	} else if room.LazyLoadSummary == nil ||
		!slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) ||
		!intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) ||
		!intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) {
		updatedRoom.LazyLoadSummary = summary
		heroesChanged = true
	}
	decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
	allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events))
	newNotifications := make([]SyncNotification, 0)
	var recalculatePreviewEvent, unreadMessagesWereMaybeRedacted bool
	var newUnreadCounts database.UnreadCounts
	// addOldEvent loads an already stored event (by row ID if non-zero, otherwise by
	// event ID) and includes it in the events sent along with this sync.
	// It returns nil if the event isn't found.
	addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) {
		if rowID != 0 {
			dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID)
		} else {
			dbEvt, err = h.DB.Event.GetByID(ctx, evtID)
		}
		if err != nil {
			return nil, fmt.Errorf("failed to get event from database: %w", err)
		} else if dbEvt == nil {
			return nil, nil
		}
		allNewEvents = append(allNewEvents, dbEvt)
		return dbEvt, nil
	}
	processRedaction := func(evt *event.Event) error {
		dbEvt, err := addOldEvent(0, evt.Redacts)
		if err != nil {
			return fmt.Errorf("failed to get redaction target: %w", err)
		}
		if dbEvt == nil {
			return nil
		}
		if dbEvt.UnreadType > 0 {
			unreadMessagesWereMaybeRedacted = true
		}
		// Redacted edits and reactions relate to another event; include that event too.
		if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
			_, err = addOldEvent(0, dbEvt.RelatesTo)
			if err != nil {
				return fmt.Errorf("failed to get relation target of redaction target: %w", err)
			}
		}
		// The redacted event may be the preview picked earlier in this sync. If no
		// preview has been picked yet in this sync, it may be the stored preview; the
		// stored one must be checked too, because updatedRoom starts without a preview.
		if updatedRoom.PreviewEventRowID == dbEvt.RowID ||
			(updatedRoom.PreviewEventRowID == 0 && room.PreviewEventRowID == dbEvt.RowID) {
			updatedRoom.PreviewEventRowID = 0
			recalculatePreviewEvent = true
		}
		return nil
	}
	processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) {
		evt.RoomID = room.ID
		dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, false)
		if err != nil {
			return -1, err
		}
		if isUnread {
			if dbEvt.UnreadType.Is(database.UnreadTypeNotify) {
				newNotifications = append(newNotifications, SyncNotification{
					RowID: dbEvt.RowID,
					Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound),
				})
			}
			newUnreadCounts.AddOne(dbEvt.UnreadType)
		}
		if isTimeline {
			if dbEvt.CanUseForPreview() {
				updatedRoom.PreviewEventRowID = dbEvt.RowID
				recalculatePreviewEvent = false
			}
			updatedRoom.BumpSortingTimestamp(dbEvt)
		}
		if evt.StateKey != nil {
			var membership event.Membership
			if evt.Type == event.StateMember {
				membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str)
				if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) {
					heroesChanged = true
				}
			} else if evt.Type == event.StateElementFunctionalMembers {
				heroesChanged = true
			}
			err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership)
			if err != nil {
				return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err)
			}
			processImportantEvent(ctx, evt, room, updatedRoom)
		}
		allNewEvents = append(allNewEvents, dbEvt)
		if evt.Type == event.EventRedaction && evt.Redacts != "" {
			err = processRedaction(evt)
			if err != nil {
				return -1, fmt.Errorf("failed to process redaction: %w", err)
			}
		} else if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
			_, err = addOldEvent(0, dbEvt.RelatesTo)
			if err != nil {
				return -1, fmt.Errorf("failed to get relation target of event: %w", err)
			}
		}
		return dbEvt.RowID, nil
	}
	changedState := make(map[event.Type]map[string]database.EventRowID)
	setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) {
		if _, ok := changedState[evtType]; !ok {
			changedState[evtType] = make(map[string]database.EventRowID)
		}
		changedState[evtType][stateKey] = rowID
	}
	for _, evt := range state.Events {
		evt.Type.Class = event.StateEventType
		rowID, err := processNewEvent(evt, false, false)
		if err != nil {
			return err
		}
		setNewState(evt.Type, *evt.StateKey, rowID)
	}
	var timelineRowTuples []database.TimelineRowTuple
	var err error
	if len(timeline.Events) > 0 {
		timelineIDs := make([]database.EventRowID, len(timeline.Events))
		readUpToIndex := -1
		for i := len(timeline.Events) - 1; i >= 0; i-- {
			if slices.Contains(newOwnReceipts, timeline.Events[i].ID) {
				readUpToIndex = i
				// Reset unread counts if we see our own read receipt in the timeline.
				// It'll be updated with new unreads (if any) at the end.
				updatedRoom.UnreadCounts = database.UnreadCounts{}
				break
			}
		}
		for i, evt := range timeline.Events {
			if evt.StateKey != nil {
				evt.Type.Class = event.StateEventType
			} else {
				evt.Type.Class = event.MessageEventType
			}
			timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex)
			if err != nil {
				return err
			}
			if evt.StateKey != nil {
				setNewState(evt.Type, *evt.StateKey, timelineIDs[i])
			}
		}
		if updatedRoom.SortingTimestamp.Before(unsetSortingTimestamp) {
			updatedRoom.SortingTimestamp = jsontime.UM(time.UnixMilli(timeline.Events[len(timeline.Events)-1].Timestamp))
		}
		for _, entry := range decryptionQueue {
			err = h.DB.SessionRequest.Put(ctx, entry)
			if err != nil {
				return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
			}
		}
		if len(decryptionQueue) > 0 {
			ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
		}
		if timeline.Limited {
			err = h.DB.Timeline.Clear(ctx, room.ID)
			if err != nil {
				return fmt.Errorf("failed to clear old timeline: %w", err)
			}
			updatedRoom.PrevBatch = timeline.PrevBatch
			h.paginationInterrupterLock.Lock()
			if interrupt, ok := h.paginationInterrupter[room.ID]; ok {
				interrupt(ErrTimelineReset)
			}
			h.paginationInterrupterLock.Unlock()
		}
		timelineRowTuples, err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs)
		if err != nil {
			return fmt.Errorf("failed to append timeline: %w", err)
		}
	} else {
		timelineRowTuples = make([]database.TimelineRowTuple, 0)
	}
	if recalculatePreviewEvent && updatedRoom.PreviewEventRowID == 0 {
		updatedRoom.PreviewEventRowID, err = h.DB.Room.RecalculatePreview(ctx, room.ID)
		if err != nil {
			return fmt.Errorf("failed to recalculate preview event: %w", err)
		}
		// Skip the lookup when no preview event was found.
		if updatedRoom.PreviewEventRowID != 0 {
			_, err = addOldEvent(updatedRoom.PreviewEventRowID, "")
			if err != nil {
				return fmt.Errorf("failed to get preview event: %w", err)
			}
		}
	}
	// Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset
	if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil {
		name, dmAvatarURL, err := h.calculateRoomParticipantName(ctx, room.ID, summary)
		if err != nil {
			return fmt.Errorf("failed to calculate room name: %w", err)
		}
		updatedRoom.Name = &name
		updatedRoom.NameQuality = database.NameQualityParticipants
		if !dmAvatarURL.IsEmpty() && !room.ExplicitAvatar {
			updatedRoom.Avatar = &dmAvatarURL
		}
	}
	if len(receipts) > 0 {
		err = h.DB.Receipt.PutMany(ctx, room.ID, receipts...)
		if err != nil {
			return fmt.Errorf("failed to save receipts: %w", err)
		}
	}
	if !room.UnreadCounts.IsZero() && ((len(newOwnReceipts) > 0 && newUnreadCounts.IsZero()) || unreadMessagesWereMaybeRedacted) {
		updatedRoom.UnreadCounts, err = h.DB.Room.CalculateUnreads(ctx, room.ID, h.Account.UserID)
		if err != nil {
			return fmt.Errorf("failed to recalculate unread counts: %w", err)
		}
	} else {
		updatedRoom.UnreadCounts.Add(newUnreadCounts)
	}
	if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) {
		updatedRoom.PrevBatch = timeline.PrevBatch
	}
	roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
	if roomChanged {
		err = h.DB.Room.Upsert(ctx, updatedRoom)
		if err != nil {
			return fmt.Errorf("failed to save room data: %w", err)
		}
	}
	// TODO why is *old* unread count sometimes zero when processing the read receipt that is making it zero?
	if roomChanged || len(newOwnReceipts) > 0 || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 {
		ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
			Meta:          room,
			Timeline:      timelineRowTuples,
			State:         changedState,
			Reset:         timeline.Limited,
			Events:        allNewEvents,
			Notifications: newNotifications,
		}
	}
	return nil
}
// joinMemberNames formats member names into a room name. Examples: "Alice",
// "Alice and Bob", "Alice, Bob and Carol", "Alice, Bob, Carol, Dave and 3 others".
//
// totalCount is the member count supplied by the caller. Up to four names are listed
// explicitly; exactly five names are all listed when totalCount is at most 6. Beyond
// that, the rest is summarised as "N others". The result is "" for an empty list.
func joinMemberNames(names []string, totalCount int) string {
	switch {
	case len(names) == 0:
		return ""
	case len(names) == 1:
		return names[0]
	case len(names) < 5 || (len(names) == 5 && totalCount <= 6):
		return strings.Join(names[:len(names)-1], ", ") + " and " + names[len(names)-1]
	default:
		// totalCount-5 presumably counts the current user along with the four listed
		// names. The count can be smaller than the name list (e.g. for former members),
		// so never report fewer "others" than there are unlisted names.
		others := max(totalCount-5, len(names)-4)
		return fmt.Sprintf("%s and %d others", strings.Join(names[:4], ", "), others)
	}
}
// calculateRoomParticipantName builds a room name from the heroes of the lazy-loading
// summary.
//
// Naming rules:
//   - heroes listed in the functional members state event are skipped and subtracted
//     from the member count;
//   - heroes with a join/invite membership form the name;
//   - if there are none, the remaining heroes produce "Empty room (was ...)".
//
// It also returns the single resolved hero's avatar, for use as a DM avatar. That
// avatar is empty unless exactly one hero was resolved and it has a valid avatar URL.
func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, id.ContentURI, error) {
	var avatar id.ContentURI
	if summary == nil || len(summary.Heroes) == 0 {
		return "Empty room", avatar, nil
	}
	var serviceMembers []id.UserID
	fmEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "")
	if err != nil {
		return "", avatar, fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err)
	}
	if fmEvt != nil {
		rawEvt := fmEvt.AsRawMautrix()
		_ = rawEvt.Content.ParseRaw(rawEvt.Type)
		if fmContent, ok := rawEvt.Content.Parsed.(*event.ElementFunctionalMembersContent); ok {
			serviceMembers = fmContent.ServiceMembers
		}
	}
	memberCount := 0
	switch {
	case summary.JoinedMemberCount != nil && *summary.JoinedMemberCount > 0:
		memberCount = *summary.JoinedMemberCount
	case summary.InvitedMemberCount != nil:
		memberCount = *summary.InvitedMemberCount
	}
	var current, former []string
	for _, hero := range summary.Heroes {
		if slices.Contains(serviceMembers, hero) {
			memberCount--
			continue
		}
		if len(current) >= 5 {
			break
		}
		memberEvt, getErr := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String())
		if getErr != nil {
			return "", avatar, fmt.Errorf("failed to get %s's member event: %w", hero, getErr)
		}
		if memberEvt == nil {
			former = append(former, hero.String())
			continue
		}
		displayName := gjson.GetBytes(memberEvt.Content, "displayname").Str
		if displayName == "" {
			displayName = hero.String()
		}
		if avatarURL := gjson.GetBytes(memberEvt.Content, "avatar_url").Str; avatarURL != "" {
			avatar = id.ContentURIString(avatarURL).ParseOrIgnore()
		}
		switch gjson.GetBytes(memberEvt.Content, "membership").Str {
		case "join", "invite":
			current = append(current, displayName)
		default:
			former = append(former, displayName)
		}
	}
	if len(current)+len(former) > 1 || !avatar.IsValid() {
		avatar = id.ContentURI{}
	}
	switch {
	case len(current) > 0:
		return joinMemberNames(current, memberCount), avatar, nil
	case len(former) > 0:
		return fmt.Sprintf("Empty room (was %s)", joinMemberNames(former, memberCount)), avatar, nil
	default:
		return "Empty room", avatar, nil
	}
}
// intPtrEqual reports whether two optional ints are equal. They are equal when both are
// nil, or when both are non-nil and point to the same value.
func intPtrEqual(a, b *int) bool {
	if a != nil && b != nil {
		return *a == *b
	}
	return a == nil && b == nil
}
// processImportantEvent applies room-level state events with an empty state key to
// updatedRoom: create, tombstone, name, canonical alias, avatar, topic and encryption.
// Other events are ignored.
//
// existingRoomData is the room as stored before this sync. It serves as a fallback
// source for the canonical alias, and it is used to ignore encryption events that
// would change an already set algorithm.
//
// The roomDataChanged result is never set and is always false; the signature is kept
// for compatibility.
func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData, updatedRoom *database.Room) (roomDataChanged bool) {
	if evt.StateKey == nil {
		return false
	}
	switch evt.Type {
	case event.StateCreate, event.StateTombstone, event.StateRoomName, event.StateCanonicalAlias,
		event.StateRoomAvatar, event.StateTopic, event.StateEncryption:
		if *evt.StateKey != "" {
			return false
		}
	default:
		return false
	}
	err := evt.Content.ParseRaw(evt.Type)
	if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
		zerolog.Ctx(ctx).Warn().Err(err).
			Stringer("event_type", &evt.Type).
			Stringer("event_id", evt.ID).
			Msg("Failed to parse state event, skipping")
		return false
	}
	switch evt.Type {
	case event.StateCreate:
		updatedRoom.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent)
	case event.StateTombstone:
		updatedRoom.Tombstone, _ = evt.Content.Parsed.(*event.TombstoneEventContent)
	case event.StateEncryption:
		newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent)
		// The nil check keeps an unexpected content type from causing a nil dereference
		// in the algorithm comparison.
		if newEncryption != nil &&
			(existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm) {
			updatedRoom.EncryptionEvent = newEncryption
		}
	case event.StateRoomName:
		content, ok := evt.Content.Parsed.(*event.RoomNameEventContent)
		if ok {
			updatedRoom.Name = &content.Name
			updatedRoom.NameQuality = database.NameQualityExplicit
			if content.Name == "" {
				// An empty name falls back to the canonical alias (new, then stored),
				// or leaves the name to be calculated from participants.
				if updatedRoom.CanonicalAlias != nil && *updatedRoom.CanonicalAlias != "" {
					updatedRoom.Name = (*string)(updatedRoom.CanonicalAlias)
					updatedRoom.NameQuality = database.NameQualityCanonicalAlias
				} else if existingRoomData.CanonicalAlias != nil && *existingRoomData.CanonicalAlias != "" {
					updatedRoom.Name = (*string)(existingRoomData.CanonicalAlias)
					updatedRoom.NameQuality = database.NameQualityCanonicalAlias
				} else {
					updatedRoom.NameQuality = database.NameQualityNil
				}
			}
		}
	case event.StateCanonicalAlias:
		content, ok := evt.Content.Parsed.(*event.CanonicalAliasEventContent)
		if ok {
			updatedRoom.CanonicalAlias = &content.Alias
			// Only override the name if it isn't already of a higher quality (e.g. explicit).
			if updatedRoom.NameQuality <= database.NameQualityCanonicalAlias {
				updatedRoom.Name = (*string)(&content.Alias)
				updatedRoom.NameQuality = database.NameQualityCanonicalAlias
				if content.Alias == "" {
					updatedRoom.NameQuality = database.NameQualityNil
				}
			}
		}
	case event.StateRoomAvatar:
		content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
		if ok {
			url, _ := content.URL.Parse()
			updatedRoom.Avatar = &url
			updatedRoom.ExplicitAvatar = true
		}
	case event.StateTopic:
		content, ok := evt.Content.Parsed.(*event.TopicEventContent)
		if ok {
			updatedRoom.Topic = &content.Topic
		}
	}
	return false
}