forked from Mirrors/gomuks

Edits now use a different HTML-to-Markdown converter than the one used to generate the body. This allows the plaintext body to contain a plain shortcode for custom emojis, while edits still have access to the raw data. Additionally, the raw input of sent events is saved locally, which allows preserving commands and similar constructs. A future extension may store the raw input in a custom field in the Matrix event to allow lossless edits of messages sent from other clients.
972 lines
33 KiB
Go
972 lines
33 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package hicli
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
"go.mau.fi/util/emojirunes"
|
|
"go.mau.fi/util/exzerolog"
|
|
"go.mau.fi/util/jsontime"
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/crypto"
|
|
"maunium.net/go/mautrix/crypto/olm"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/format"
|
|
"maunium.net/go/mautrix/id"
|
|
"maunium.net/go/mautrix/pushrules"
|
|
|
|
"go.mau.fi/gomuks/pkg/hicli/database"
|
|
)
|
|
|
|
type syncContext struct {
|
|
shouldWakeupRequestQueue bool
|
|
|
|
evt *SyncComplete
|
|
}
|
|
|
|
func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
|
log := zerolog.Ctx(ctx)
|
|
postponedToDevices := resp.ToDevice.Events[:0]
|
|
for _, evt := range resp.ToDevice.Events {
|
|
evt.Type.Class = event.ToDeviceEventType
|
|
err := evt.Content.ParseRaw(evt.Type)
|
|
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
|
log.Warn().Err(err).
|
|
Stringer("event_type", &evt.Type).
|
|
Stringer("sender", evt.Sender).
|
|
Msg("Failed to parse to-device event, skipping")
|
|
continue
|
|
}
|
|
|
|
switch content := evt.Content.Parsed.(type) {
|
|
case *event.EncryptedEventContent:
|
|
h.Crypto.HandleEncryptedEvent(ctx, evt)
|
|
case *event.RoomKeyWithheldEventContent:
|
|
h.Crypto.HandleRoomKeyWithheld(ctx, content)
|
|
default:
|
|
postponedToDevices = append(postponedToDevices, evt)
|
|
}
|
|
}
|
|
resp.ToDevice.Events = postponedToDevices
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
|
|
h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
|
|
go h.asyncPostProcessSyncResponse(ctx, resp, since)
|
|
syncCtx := ctx.Value(syncContextKey).(*syncContext)
|
|
if syncCtx.shouldWakeupRequestQueue {
|
|
h.WakeupRequestQueue()
|
|
}
|
|
if !h.firstSyncReceived {
|
|
h.firstSyncReceived = true
|
|
h.Client.Client.Transport.(*http.Transport).ResponseHeaderTimeout = 60 * time.Second
|
|
h.Client.Client.Timeout = 180 * time.Second
|
|
}
|
|
if !syncCtx.evt.IsEmpty() {
|
|
h.EventHandler(syncCtx.evt)
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
|
|
for _, evt := range resp.ToDevice.Events {
|
|
switch content := evt.Content.Parsed.(type) {
|
|
case *event.SecretRequestEventContent:
|
|
h.Crypto.HandleSecretRequest(ctx, evt.Sender, content)
|
|
case *event.RoomKeyRequestEventContent:
|
|
h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
|
if len(resp.DeviceLists.Changed) > 0 {
|
|
zerolog.Ctx(ctx).Debug().
|
|
Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)).
|
|
Msg("Marking changed device lists for tracked users as outdated")
|
|
err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to mark changed device lists as outdated: %w", err)
|
|
}
|
|
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
|
|
}
|
|
|
|
accountData := make(map[event.Type]*database.AccountData, len(resp.AccountData.Events))
|
|
var err error
|
|
for _, evt := range resp.AccountData.Events {
|
|
evt.Type.Class = event.AccountDataEventType
|
|
accountData[evt.Type], err = h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
|
|
}
|
|
if evt.Type == event.AccountDataPushRules {
|
|
err = evt.Content.ParseRaw(evt.Type)
|
|
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
|
zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to parse push rules in sync")
|
|
} else if pushRules, ok := evt.Content.Parsed.(*pushrules.EventContent); ok {
|
|
h.receiveNewPushRules(ctx, pushRules.Ruleset)
|
|
zerolog.Ctx(ctx).Debug().Msg("Updated push rules from sync")
|
|
}
|
|
}
|
|
}
|
|
ctx.Value(syncContextKey).(*syncContext).evt.AccountData = accountData
|
|
for roomID, room := range resp.Rooms.Join {
|
|
err := h.processSyncJoinedRoom(ctx, roomID, room)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process joined room %s: %w", roomID, err)
|
|
}
|
|
}
|
|
for roomID, room := range resp.Rooms.Leave {
|
|
err := h.processSyncLeftRoom(ctx, roomID, room)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process left room %s: %w", roomID, err)
|
|
}
|
|
}
|
|
h.Account.NextBatch = resp.NextBatch
|
|
err = h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save next_batch: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) {
|
|
receiptList := make([]*database.Receipt, 0)
|
|
var newOwnReceipts []id.EventID
|
|
for eventID, receipts := range *content {
|
|
for receiptType, users := range receipts {
|
|
for userID, receiptInfo := range users {
|
|
if userID == h.Account.UserID {
|
|
newOwnReceipts = append(newOwnReceipts, eventID)
|
|
}
|
|
receiptList = append(receiptList, &database.Receipt{
|
|
UserID: userID,
|
|
ReceiptType: receiptType,
|
|
ThreadID: receiptInfo.ThreadID,
|
|
EventID: eventID,
|
|
Timestamp: jsontime.UM(receiptInfo.Timestamp),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
return receiptList, newOwnReceipts
|
|
}
|
|
|
|
func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
|
|
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get room data: %w", err)
|
|
} else if existingRoomData == nil {
|
|
err = h.DB.Room.CreateRow(ctx, roomID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to ensure room row exists: %w", err)
|
|
}
|
|
existingRoomData = &database.Room{
|
|
ID: roomID,
|
|
// Hack to set a default value for SortingTimestamp which is before all existing rooms,
|
|
// but not the same for all rooms without a timestamp.
|
|
SortingTimestamp: jsontime.UM(time.UnixMilli(time.Now().Unix())),
|
|
}
|
|
}
|
|
|
|
accountData := make(map[event.Type]*database.AccountData, len(room.AccountData.Events))
|
|
for _, evt := range room.AccountData.Events {
|
|
evt.Type.Class = event.AccountDataEventType
|
|
evt.RoomID = roomID
|
|
accountData[evt.Type], err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
|
|
}
|
|
}
|
|
var receiptsList []*database.Receipt
|
|
var newOwnReceipts []id.EventID
|
|
for _, evt := range room.Ephemeral.Events {
|
|
evt.Type.Class = event.EphemeralEventType
|
|
err = evt.Content.ParseRaw(evt.Type)
|
|
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
|
zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content")
|
|
continue
|
|
}
|
|
switch evt.Type {
|
|
case event.EphemeralEventReceipt:
|
|
list, ownList := h.receiptsToList(evt.Content.AsReceipt())
|
|
receiptsList = append(receiptsList, list...)
|
|
newOwnReceipts = append(newOwnReceipts, ownList...)
|
|
case event.EphemeralEventTyping:
|
|
go h.EventHandler(&Typing{
|
|
RoomID: roomID,
|
|
TypingEventContent: *evt.Content.AsTyping(),
|
|
})
|
|
}
|
|
}
|
|
err = h.processStateAndTimeline(
|
|
ctx,
|
|
existingRoomData,
|
|
&room.State,
|
|
&room.Timeline,
|
|
&room.Summary,
|
|
receiptsList,
|
|
newOwnReceipts,
|
|
accountData,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error {
|
|
zerolog.Ctx(ctx).Debug().Stringer("room_id", roomID).Msg("Deleting left room")
|
|
err := h.DB.Room.Delete(ctx, roomID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete room: %w", err)
|
|
}
|
|
payload := ctx.Value(syncContextKey).(*syncContext).evt
|
|
payload.LeftRooms = append(payload.LeftRooms, roomID)
|
|
return nil
|
|
}
|
|
|
|
func isDecryptionErrorRetryable(err error) bool {
|
|
return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld)
|
|
}
|
|
|
|
func removeReplyFallback(evt *event.Event) []byte {
|
|
if evt.Type != event.EventMessage && evt.Type != event.EventSticker {
|
|
return nil
|
|
}
|
|
_ = evt.Content.ParseRaw(evt.Type)
|
|
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
|
if ok && content.RelatesTo.GetReplyTo() != "" {
|
|
prevFormattedBody := content.FormattedBody
|
|
content.RemoveReplyFallback()
|
|
if content.FormattedBody != prevFormattedBody {
|
|
bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody)
|
|
bytes, err2 := sjson.SetBytes(bytes, "body", content.Body)
|
|
if err == nil && err2 == nil {
|
|
return bytes
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) {
|
|
err := evt.Content.ParseRaw(evt.Type)
|
|
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
|
return nil, nil, "", err
|
|
}
|
|
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
|
|
if err != nil {
|
|
return nil, nil, "", err
|
|
}
|
|
withoutFallback := removeReplyFallback(decrypted)
|
|
if withoutFallback != nil {
|
|
return decrypted, withoutFallback, decrypted.Type.Type, nil
|
|
}
|
|
return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil
|
|
}
|
|
|
|
func (h *HiClient) addMediaCache(
|
|
ctx context.Context,
|
|
eventRowID database.EventRowID,
|
|
uri id.ContentURIString,
|
|
file *event.EncryptedFileInfo,
|
|
info *event.FileInfo,
|
|
fileName string,
|
|
) {
|
|
parsedMXC := uri.ParseOrIgnore()
|
|
if !parsedMXC.IsValid() {
|
|
return
|
|
}
|
|
cm := &database.Media{
|
|
MXC: parsedMXC,
|
|
FileName: fileName,
|
|
}
|
|
if file != nil {
|
|
cm.EncFile = &file.EncryptedFile
|
|
}
|
|
if info != nil {
|
|
cm.MimeType = info.MimeType
|
|
}
|
|
err := h.DB.Media.Put(ctx, cm)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Warn().Err(err).
|
|
Stringer("mxc", parsedMXC).
|
|
Int64("event_rowid", int64(eventRowID)).
|
|
Msg("Failed to add database media entry")
|
|
return
|
|
}
|
|
err = h.DB.Media.AddReference(ctx, eventRowID, parsedMXC)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Warn().Err(err).
|
|
Stringer("mxc", parsedMXC).
|
|
Int64("event_rowid", int64(eventRowID)).
|
|
Msg("Failed to add database media reference")
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) {
|
|
switch evt.Type {
|
|
case event.EventMessage, event.EventSticker:
|
|
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
|
if !ok {
|
|
return
|
|
}
|
|
if content.File != nil {
|
|
h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName())
|
|
} else if content.URL != "" {
|
|
h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName())
|
|
}
|
|
if content.GetInfo().ThumbnailFile != nil {
|
|
h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "")
|
|
} else if content.GetInfo().ThumbnailURL != "" {
|
|
h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "")
|
|
}
|
|
case event.StateRoomAvatar:
|
|
_ = evt.Content.ParseRaw(evt.Type)
|
|
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
|
|
if !ok {
|
|
return
|
|
}
|
|
h.addMediaCache(ctx, rowID, content.URL, nil, nil, "")
|
|
case event.StateMember:
|
|
_ = evt.Content.ParseRaw(evt.Type)
|
|
content, ok := evt.Content.Parsed.(*event.MemberEventContent)
|
|
if !ok {
|
|
return
|
|
}
|
|
h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "")
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) (*database.LocalContent, []id.ContentURI) {
|
|
if evt.Type != event.EventMessage {
|
|
return nil, nil
|
|
}
|
|
_ = evt.Content.ParseRaw(evt.Type)
|
|
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
|
if !ok {
|
|
return nil, nil
|
|
}
|
|
if dbEvt.RelationType == event.RelReplace && content.NewContent != nil {
|
|
content = content.NewContent
|
|
}
|
|
if content != nil {
|
|
var sanitizedHTML, editSource string
|
|
var wasPlaintext, hasMath, bigEmoji bool
|
|
var inlineImages []id.ContentURI
|
|
if content.Format == event.FormatHTML && content.FormattedBody != "" {
|
|
var err error
|
|
sanitizedHTML, inlineImages, err = sanitizeAndLinkifyHTML(content.FormattedBody)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Warn().Err(err).
|
|
Stringer("event_id", dbEvt.ID).
|
|
Msg("Failed to sanitize HTML")
|
|
}
|
|
hasMath = strings.Contains(sanitizedHTML, "<hicli-math")
|
|
if len(inlineImages) > 0 && dbEvt.RowID != 0 {
|
|
for _, uri := range inlineImages {
|
|
h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
|
|
}
|
|
inlineImages = nil
|
|
}
|
|
if dbEvt.LocalContent != nil && dbEvt.LocalContent.EditSource != "" {
|
|
editSource = dbEvt.LocalContent.EditSource
|
|
} else if evt.Sender == h.Account.UserID {
|
|
editSource, _ = format.HTMLToMarkdownFull(htmlToMarkdownForInput, content.FormattedBody)
|
|
if content.MsgType == event.MsgEmote {
|
|
editSource = "/me " + editSource
|
|
} else if content.MsgType == event.MsgNotice {
|
|
editSource = "/notice " + editSource
|
|
}
|
|
}
|
|
} else {
|
|
hasSpecialCharacters := false
|
|
for _, char := range content.Body {
|
|
if char == '<' || char == '>' || char == '&' || char == '.' || char == ':' {
|
|
hasSpecialCharacters = true
|
|
break
|
|
}
|
|
}
|
|
if hasSpecialCharacters {
|
|
var builder strings.Builder
|
|
builder.Grow(len(content.Body) + builderPreallocBuffer)
|
|
linkifyAndWriteBytes(&builder, []byte(content.Body))
|
|
sanitizedHTML = builder.String()
|
|
} else if len(content.Body) < 100 && emojirunes.IsOnlyEmojis(content.Body) {
|
|
bigEmoji = true
|
|
}
|
|
if content.MsgType == event.MsgEmote {
|
|
editSource = "/me " + content.Body
|
|
} else if content.MsgType == event.MsgNotice {
|
|
editSource = "/notice " + content.Body
|
|
}
|
|
wasPlaintext = true
|
|
}
|
|
return &database.LocalContent{
|
|
SanitizedHTML: sanitizedHTML,
|
|
HTMLVersion: CurrentHTMLSanitizerVersion,
|
|
WasPlaintext: wasPlaintext,
|
|
BigEmoji: bigEmoji,
|
|
HasMath: hasMath,
|
|
EditSource: editSource,
|
|
}, inlineImages
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// CurrentHTMLSanitizerVersion is the version of the local content rendering
// (HTML sanitization) logic. calculateLocalContent stores it as HTMLVersion
// in each message's LocalContent. ReprocessExistingEvent recomputes stored
// message events whose HTMLVersion is lower than this value.
const CurrentHTMLSanitizerVersion = 6
|
|
|
|
func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Event) {
|
|
if (evt.Type != event.EventMessage.Type && evt.DecryptedType != event.EventMessage.Type) ||
|
|
evt.LocalContent == nil || evt.LocalContent.HTMLVersion >= CurrentHTMLSanitizerVersion {
|
|
return
|
|
}
|
|
evt.LocalContent, _ = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix())
|
|
err := h.DB.Event.UpdateLocalContent(ctx, evt)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Err(err).
|
|
Stringer("event_id", evt.ID).
|
|
Msg("Failed to update local content")
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) (inlineImages []id.ContentURI) {
|
|
if dbEvt.RowID != 0 {
|
|
h.cacheMedia(ctx, evt, dbEvt.RowID)
|
|
}
|
|
if evt.Sender != h.Account.UserID {
|
|
dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt)
|
|
}
|
|
dbEvt.LocalContent, inlineImages = h.calculateLocalContent(ctx, dbEvt, evt)
|
|
return
|
|
}
|
|
|
|
func (h *HiClient) processEvent(
|
|
ctx context.Context,
|
|
evt *event.Event,
|
|
llSummary *mautrix.LazyLoadSummary,
|
|
decryptionQueue map[id.SessionID]*database.SessionRequest,
|
|
checkDB bool,
|
|
) (*database.Event, error) {
|
|
if checkDB {
|
|
dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to check if event %s exists: %w", evt.ID, err)
|
|
} else if dbEvt != nil {
|
|
return dbEvt, nil
|
|
}
|
|
}
|
|
dbEvt := database.MautrixToEvent(evt)
|
|
dbEvt.Reactions = make(map[string]int)
|
|
contentWithoutFallback := removeReplyFallback(evt)
|
|
if contentWithoutFallback != nil {
|
|
dbEvt.Content = contentWithoutFallback
|
|
}
|
|
var decryptionErr error
|
|
var decryptedMautrixEvt *event.Event
|
|
if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
|
|
decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
|
|
if decryptionErr != nil {
|
|
dbEvt.DecryptionError = decryptionErr.Error()
|
|
}
|
|
} else if evt.Type == event.EventRedaction {
|
|
if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() {
|
|
var err error
|
|
evt.Content.VeryRaw, err = sjson.SetBytes(evt.Content.VeryRaw, "redacts", evt.Redacts)
|
|
if err != nil {
|
|
return dbEvt, fmt.Errorf("failed to set redacts field: %w", err)
|
|
}
|
|
} else if evt.Redacts == "" {
|
|
evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str)
|
|
}
|
|
}
|
|
var inlineImages []id.ContentURI
|
|
if decryptedMautrixEvt != nil {
|
|
inlineImages = h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt)
|
|
} else {
|
|
inlineImages = h.postDecryptProcess(ctx, llSummary, dbEvt, evt)
|
|
}
|
|
_, err := h.DB.Event.Upsert(ctx, dbEvt)
|
|
if err != nil {
|
|
return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
|
|
}
|
|
if decryptedMautrixEvt != nil {
|
|
h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID)
|
|
} else {
|
|
h.cacheMedia(ctx, evt, dbEvt.RowID)
|
|
}
|
|
for _, uri := range inlineImages {
|
|
h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
|
|
}
|
|
if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
|
|
req, ok := decryptionQueue[dbEvt.MegolmSessionID]
|
|
if !ok {
|
|
req = &database.SessionRequest{
|
|
RoomID: evt.RoomID,
|
|
SessionID: dbEvt.MegolmSessionID,
|
|
Sender: evt.Sender,
|
|
}
|
|
}
|
|
minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext)
|
|
req.MinIndex = min(uint32(minIndex), req.MinIndex)
|
|
if decryptionQueue != nil {
|
|
decryptionQueue[dbEvt.MegolmSessionID] = req
|
|
} else {
|
|
err = h.DB.SessionRequest.Put(ctx, req)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Err(err).
|
|
Stringer("session_id", dbEvt.MegolmSessionID).
|
|
Msg("Failed to save session request")
|
|
} else {
|
|
h.WakeupRequestQueue()
|
|
}
|
|
}
|
|
}
|
|
return dbEvt, err
|
|
}
|
|
|
|
// unsetSortingTimestamp is the threshold below which a room's SortingTimestamp
// is treated as a placeholder rather than a real event time. The value is
// 1000000000000 ms after the Unix epoch, in September 2001. Newly created
// rooms get a placeholder far below it: the current Unix time in seconds,
// interpreted as milliseconds. processStateAndTimeline replaces such a
// placeholder with the timestamp of the room's last timeline event.
var unsetSortingTimestamp = time.UnixMilli(1000000000000)
|
|
|
|
// processStateAndTimeline stores the state and timeline events of a single
// room and updates the room's metadata accordingly.
//
// Parameters:
//   - room: the currently stored room data. At the end, changes are applied
//     via updatedRoom.CheckChangesAndCopyInto(room), which presumably copies
//     them into room; room is then used as the sync payload's Meta.
//   - state: state events preceding the timeline. They are stored and
//     recorded as current state, but never counted as unread and never used
//     for previews or the sorting timestamp.
//   - timeline: timeline events in order. A limited timeline clears the
//     stored timeline and interrupts any running pagination for the room.
//   - summary: the lazy-loading summary. If it carries no data, the stored
//     summary is used instead.
//   - receipts: receipts to store.
//   - newOwnReceipts: event IDs that the logged-in user's receipts point at,
//     used for unread tracking.
//   - accountData: room account data already stored by the caller. It is only
//     forwarded in the sync payload.
//
// If anything changed, a SyncRoom entry is written into the SyncComplete
// payload found in ctx under syncContextKey. Returns the first error
// encountered.
func (h *HiClient) processStateAndTimeline(
	ctx context.Context,
	room *database.Room,
	state *mautrix.SyncEventsList,
	timeline *mautrix.SyncTimeline,
	summary *mautrix.LazyLoadSummary,
	receipts []*database.Receipt,
	newOwnReceipts []id.EventID,
	accountData map[event.Type]*database.AccountData,
) error {
	// updatedRoom accumulates changes. Fields copied from room here are the
	// ones the processing below may read or modify incrementally.
	updatedRoom := &database.Room{
		ID: room.ID,

		SortingTimestamp: room.SortingTimestamp,
		NameQuality:      room.NameQuality,
		UnreadCounts:     room.UnreadCounts,
	}
	heroesChanged := false
	if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil {
		// No summary data in this batch: fall back to the stored summary.
		summary = room.LazyLoadSummary
	} else if room.LazyLoadSummary == nil ||
		!slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) ||
		!intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) ||
		!intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) {
		updatedRoom.LazyLoadSummary = summary
		heroesChanged = true
	}
	// Session requests for undecryptable events, saved once after the timeline is processed.
	decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
	// Every new or updated event that should be sent to the client in the sync payload.
	allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events))
	newNotifications := make([]SyncNotification, 0)
	var recalculatePreviewEvent, unreadMessagesWereMaybeRedacted bool
	var newUnreadCounts database.UnreadCounts
	// addOldEvent loads an already-stored event (by row ID if non-zero,
	// otherwise by event ID) and appends it to allNewEvents. It returns
	// (nil, nil) if the event is not stored.
	// NOTE(review): the error text says "redaction target" even though this is
	// also used for relation targets and the preview event.
	addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) {
		if rowID != 0 {
			dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID)
		} else {
			dbEvt, err = h.DB.Event.GetByID(ctx, evtID)
		}
		if err != nil {
			return nil, fmt.Errorf("failed to get redaction target: %w", err)
		} else if dbEvt == nil {
			return nil, nil
		}
		allNewEvents = append(allNewEvents, dbEvt)
		return dbEvt, nil
	}
	// processRedaction handles a redaction event. It performs these steps:
	//   - loads the redacted event into the payload;
	//   - flags an unread recount if the redacted event was unread;
	//   - also loads the target of a redacted edit/reaction;
	//   - clears the preview if it pointed at the redacted event.
	processRedaction := func(evt *event.Event) error {
		dbEvt, err := addOldEvent(0, evt.Redacts)
		if err != nil {
			return fmt.Errorf("failed to get redaction target: %w", err)
		}
		if dbEvt == nil {
			return nil
		}
		if dbEvt.UnreadType > 0 {
			unreadMessagesWereMaybeRedacted = true
		}
		if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
			_, err = addOldEvent(0, dbEvt.RelatesTo)
			if err != nil {
				return fmt.Errorf("failed to get relation target of redaction target: %w", err)
			}
		}
		if updatedRoom.PreviewEventRowID == dbEvt.RowID {
			updatedRoom.PreviewEventRowID = 0
			recalculatePreviewEvent = true
		}
		return nil
	}
	// processNewEvent stores a new event and updates room bookkeeping:
	//   - isUnread: notifications and unread counts;
	//   - isTimeline: the preview event and sorting timestamp;
	//   - state events: the current state table, hero change detection and
	//     room metadata via processImportantEvent;
	//   - redactions and edits/reactions: the related old events.
	// It returns the event's row ID, or -1 on error.
	processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) {
		evt.RoomID = room.ID
		dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, evt.Unsigned.TransactionID != "")
		if err != nil {
			return -1, err
		}
		if isUnread {
			if dbEvt.UnreadType.Is(database.UnreadTypeNotify) {
				newNotifications = append(newNotifications, SyncNotification{
					RowID: dbEvt.RowID,
					Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound),
				})
			}
			newUnreadCounts.AddOne(dbEvt.UnreadType)
		}
		if isTimeline {
			if dbEvt.CanUseForPreview() {
				updatedRoom.PreviewEventRowID = dbEvt.RowID
				recalculatePreviewEvent = false
			}
			updatedRoom.BumpSortingTimestamp(dbEvt)
		}
		if evt.StateKey != nil {
			var membership event.Membership
			if evt.Type == event.StateMember {
				membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str)
				// A hero's membership changed, so a participant-based name may need regenerating.
				if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) {
					heroesChanged = true
				}
			} else if evt.Type == event.StateElementFunctionalMembers {
				heroesChanged = true
			}
			err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership)
			if err != nil {
				return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err)
			}
			processImportantEvent(ctx, evt, room, updatedRoom)
		}
		allNewEvents = append(allNewEvents, dbEvt)
		if evt.Type == event.EventRedaction && evt.Redacts != "" {
			err = processRedaction(evt)
			if err != nil {
				return -1, fmt.Errorf("failed to process redaction: %w", err)
			}
		} else if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
			_, err = addOldEvent(0, dbEvt.RelatesTo)
			if err != nil {
				return -1, fmt.Errorf("failed to get relation target of event: %w", err)
			}
		}
		return dbEvt.RowID, nil
	}
	// changedState maps event type -> state key -> row ID of the new state event.
	changedState := make(map[event.Type]map[string]database.EventRowID)
	setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) {
		if _, ok := changedState[evtType]; !ok {
			changedState[evtType] = make(map[string]database.EventRowID)
		}
		changedState[evtType][stateKey] = rowID
	}
	// NOTE(review): dereferences StateKey unconditionally; assumes every event
	// in the state section has a state key.
	for _, evt := range state.Events {
		evt.Type.Class = event.StateEventType
		rowID, err := processNewEvent(evt, false, false)
		if err != nil {
			return err
		}
		setNewState(evt.Type, *evt.StateKey, rowID)
	}
	var timelineRowTuples []database.TimelineRowTuple
	var err error
	if len(timeline.Events) > 0 {
		timelineIDs := make([]database.EventRowID, len(timeline.Events))
		// Scan backwards for the latest event that this batch's receipts mark
		// as read by us, or that we sent ourselves. Events after it count as unread.
		readUpToIndex := -1
		for i := len(timeline.Events) - 1; i >= 0; i-- {
			evt := timeline.Events[i]
			isRead := slices.Contains(newOwnReceipts, evt.ID)
			isOwnEvent := evt.Sender == h.Account.UserID
			if isRead || isOwnEvent {
				readUpToIndex = i
				// Reset unread counts if we see our own read receipt in the timeline.
				// It'll be updated with new unreads (if any) at the end.
				updatedRoom.UnreadCounts = database.UnreadCounts{}
				if !isRead {
					// Our own message implies we've read up to it: synthesize a read receipt.
					receipts = append(receipts, &database.Receipt{
						RoomID:      room.ID,
						UserID:      h.Account.UserID,
						ReceiptType: event.ReceiptTypeRead,
						EventID:     evt.ID,
						Timestamp:   jsontime.UM(time.UnixMilli(evt.Timestamp)),
					})
					newOwnReceipts = append(newOwnReceipts, evt.ID)
				}
				break
			}
		}
		for i, evt := range timeline.Events {
			if evt.StateKey != nil {
				evt.Type.Class = event.StateEventType
			} else {
				evt.Type.Class = event.MessageEventType
			}
			timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex)
			if err != nil {
				return err
			}
			if evt.StateKey != nil {
				setNewState(evt.Type, *evt.StateKey, timelineIDs[i])
			}
		}
		// Replace a placeholder sorting timestamp (see unsetSortingTimestamp)
		// with the last timeline event's time.
		// NOTE(review): the length check is redundant inside this branch.
		if updatedRoom.SortingTimestamp.Before(unsetSortingTimestamp) && len(timeline.Events) > 0 {
			updatedRoom.SortingTimestamp = jsontime.UM(time.UnixMilli(timeline.Events[len(timeline.Events)-1].Timestamp))
		}
		// NOTE(review): session requests are only saved when the timeline is
		// non-empty; requests queued while processing state events alone are dropped.
		for _, entry := range decryptionQueue {
			err = h.DB.SessionRequest.Put(ctx, entry)
			if err != nil {
				return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
			}
		}
		if len(decryptionQueue) > 0 {
			ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
		}
		if timeline.Limited {
			// Gap in the timeline: drop the stored timeline and stop any pagination in progress for this room.
			err = h.DB.Timeline.Clear(ctx, room.ID)
			if err != nil {
				return fmt.Errorf("failed to clear old timeline: %w", err)
			}
			updatedRoom.PrevBatch = timeline.PrevBatch
			h.paginationInterrupterLock.Lock()
			if interrupt, ok := h.paginationInterrupter[room.ID]; ok {
				interrupt(ErrTimelineReset)
			}
			h.paginationInterrupterLock.Unlock()
		}
		timelineRowTuples, err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs)
		if err != nil {
			return fmt.Errorf("failed to append timeline: %w", err)
		}
	} else {
		timelineRowTuples = make([]database.TimelineRowTuple, 0)
	}
	// The preview event was redacted and no newer timeline event replaced it.
	if recalculatePreviewEvent && updatedRoom.PreviewEventRowID == 0 {
		updatedRoom.PreviewEventRowID, err = h.DB.Room.RecalculatePreview(ctx, room.ID)
		if err != nil {
			return fmt.Errorf("failed to recalculate preview event: %w", err)
		}
		_, err = addOldEvent(updatedRoom.PreviewEventRowID, "")
		if err != nil {
			return fmt.Errorf("failed to get preview event: %w", err)
		}
	}
	// Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset
	if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil {
		name, dmAvatarURL, err := h.calculateRoomParticipantName(ctx, room.ID, summary)
		if err != nil {
			return fmt.Errorf("failed to calculate room name: %w", err)
		}
		updatedRoom.Name = &name
		updatedRoom.NameQuality = database.NameQualityParticipants
		if !dmAvatarURL.IsEmpty() && !room.ExplicitAvatar {
			updatedRoom.Avatar = &dmAvatarURL
		}
	}

	if len(receipts) > 0 {
		err = h.DB.Receipt.PutMany(ctx, room.ID, receipts...)
		if err != nil {
			return fmt.Errorf("failed to save receipts: %w", err)
		}
	}
	// Recount unreads from the database when a new own receipt without new
	// unread events (or a redaction of an unread event) may have reduced them.
	// Otherwise add the newly seen unread events incrementally.
	if !room.UnreadCounts.IsZero() && ((len(newOwnReceipts) > 0 && newUnreadCounts.IsZero()) || unreadMessagesWereMaybeRedacted) {
		updatedRoom.UnreadCounts, err = h.DB.Room.CalculateUnreads(ctx, room.ID, h.Account.UserID)
		if err != nil {
			return fmt.Errorf("failed to recalculate unread counts: %w", err)
		}
	} else {
		updatedRoom.UnreadCounts.Add(newUnreadCounts)
	}
	if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) {
		updatedRoom.PrevBatch = timeline.PrevBatch
	}
	roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
	if roomChanged {
		err = h.DB.Room.Upsert(ctx, updatedRoom)
		if err != nil {
			return fmt.Errorf("failed to save room data: %w", err)
		}
	}
	// TODO why is *old* unread count sometimes zero when processing the read receipt that is making it zero?
	if roomChanged || len(accountData) > 0 || len(newOwnReceipts) > 0 || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 {
		ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
			Meta:          room,
			Timeline:      timelineRowTuples,
			AccountData:   accountData,
			State:         changedState,
			Reset:         timeline.Limited,
			Events:        allNewEvents,
			Notifications: newNotifications,
		}
	}
	return nil
}
|
|
|
|
// joinMemberNames formats member display names for a generated room name:
//   - no names yield "";
//   - one name is returned as-is;
//   - up to four names are joined as "A, B, C and D";
//   - five names are all listed only when totalCount is at most six;
//   - otherwise the first four names are listed, followed by "and N others".
//
// totalCount presumably includes the logged-in user, which is why five is
// subtracted. N is the larger of totalCount-5 and the number of names not
// shown. This keeps N at least 2 even when totalCount is missing (0) or
// smaller than the number of names passed in; previously that produced
// negative counts such as "and -4 others".
func joinMemberNames(names []string, totalCount int) string {
	switch {
	case len(names) == 0:
		return ""
	case len(names) == 1:
		return names[0]
	case len(names) < 5 || (len(names) == 5 && totalCount <= 6):
		return strings.Join(names[:len(names)-1], ", ") + " and " + names[len(names)-1]
	default:
		others := max(totalCount-5, len(names)-4)
		return fmt.Sprintf("%s and %d others", strings.Join(names[:4], ", "), others)
	}
}
|
|
|
|
func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, id.ContentURI, error) {
|
|
var primaryAvatarURL id.ContentURI
|
|
if summary == nil || len(summary.Heroes) == 0 {
|
|
return "Empty room", primaryAvatarURL, nil
|
|
}
|
|
var functionalMembers []id.UserID
|
|
functionalMembersEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "")
|
|
if err != nil {
|
|
return "", primaryAvatarURL, fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err)
|
|
} else if functionalMembersEvt != nil {
|
|
mautrixEvt := functionalMembersEvt.AsRawMautrix()
|
|
_ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type)
|
|
content, ok := mautrixEvt.Content.Parsed.(*event.ElementFunctionalMembersContent)
|
|
if ok {
|
|
functionalMembers = content.ServiceMembers
|
|
}
|
|
}
|
|
var members, leftMembers []string
|
|
var memberCount int
|
|
if summary.JoinedMemberCount != nil && *summary.JoinedMemberCount > 0 {
|
|
memberCount = *summary.JoinedMemberCount
|
|
} else if summary.InvitedMemberCount != nil {
|
|
memberCount = *summary.InvitedMemberCount
|
|
}
|
|
for _, hero := range summary.Heroes {
|
|
if slices.Contains(functionalMembers, hero) {
|
|
memberCount--
|
|
continue
|
|
} else if len(members) >= 5 {
|
|
break
|
|
}
|
|
heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String())
|
|
if err != nil {
|
|
return "", primaryAvatarURL, fmt.Errorf("failed to get %s's member event: %w", hero, err)
|
|
} else if heroEvt == nil {
|
|
leftMembers = append(leftMembers, hero.String())
|
|
continue
|
|
}
|
|
membership := gjson.GetBytes(heroEvt.Content, "membership").Str
|
|
name := gjson.GetBytes(heroEvt.Content, "displayname").Str
|
|
if name == "" {
|
|
name = hero.String()
|
|
}
|
|
avatarURL := gjson.GetBytes(heroEvt.Content, "avatar_url").Str
|
|
if avatarURL != "" {
|
|
primaryAvatarURL = id.ContentURIString(avatarURL).ParseOrIgnore()
|
|
}
|
|
if membership == "join" || membership == "invite" {
|
|
members = append(members, name)
|
|
} else {
|
|
leftMembers = append(leftMembers, name)
|
|
}
|
|
}
|
|
if len(members)+len(leftMembers) > 1 || !primaryAvatarURL.IsValid() {
|
|
primaryAvatarURL = id.ContentURI{}
|
|
}
|
|
if len(members) > 0 {
|
|
return joinMemberNames(members, memberCount), primaryAvatarURL, nil
|
|
} else if len(leftMembers) > 0 {
|
|
return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), primaryAvatarURL, nil
|
|
} else {
|
|
return "Empty room", primaryAvatarURL, nil
|
|
}
|
|
}
|
|
|
|
// intPtrEqual reports whether two optional ints are equal. Two nil pointers
// are equal, a nil pointer never equals a non-nil one, and two non-nil
// pointers are compared by the values they point to.
func intPtrEqual(a, b *int) bool {
	switch {
	case a == nil && b == nil:
		return true
	case a == nil || b == nil:
		return false
	default:
		return *a == *b
	}
}
|
|
|
|
func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData, updatedRoom *database.Room) (roomDataChanged bool) {
|
|
if evt.StateKey == nil {
|
|
return
|
|
}
|
|
switch evt.Type {
|
|
case event.StateCreate, event.StateTombstone, event.StateRoomName, event.StateCanonicalAlias,
|
|
event.StateRoomAvatar, event.StateTopic, event.StateEncryption:
|
|
if *evt.StateKey != "" {
|
|
return
|
|
}
|
|
default:
|
|
return
|
|
}
|
|
err := evt.Content.ParseRaw(evt.Type)
|
|
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
|
zerolog.Ctx(ctx).Warn().Err(err).
|
|
Stringer("event_type", &evt.Type).
|
|
Stringer("event_id", evt.ID).
|
|
Msg("Failed to parse state event, skipping")
|
|
return
|
|
}
|
|
switch evt.Type {
|
|
case event.StateCreate:
|
|
updatedRoom.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent)
|
|
case event.StateTombstone:
|
|
updatedRoom.Tombstone, _ = evt.Content.Parsed.(*event.TombstoneEventContent)
|
|
case event.StateEncryption:
|
|
newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent)
|
|
if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm {
|
|
updatedRoom.EncryptionEvent = newEncryption
|
|
}
|
|
case event.StateRoomName:
|
|
content, ok := evt.Content.Parsed.(*event.RoomNameEventContent)
|
|
if ok {
|
|
updatedRoom.Name = &content.Name
|
|
updatedRoom.NameQuality = database.NameQualityExplicit
|
|
if content.Name == "" {
|
|
if updatedRoom.CanonicalAlias != nil && *updatedRoom.CanonicalAlias != "" {
|
|
updatedRoom.Name = (*string)(updatedRoom.CanonicalAlias)
|
|
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
|
|
} else if existingRoomData.CanonicalAlias != nil && *existingRoomData.CanonicalAlias != "" {
|
|
updatedRoom.Name = (*string)(existingRoomData.CanonicalAlias)
|
|
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
|
|
} else {
|
|
updatedRoom.NameQuality = database.NameQualityNil
|
|
}
|
|
}
|
|
}
|
|
case event.StateCanonicalAlias:
|
|
content, ok := evt.Content.Parsed.(*event.CanonicalAliasEventContent)
|
|
if ok {
|
|
updatedRoom.CanonicalAlias = &content.Alias
|
|
if updatedRoom.NameQuality <= database.NameQualityCanonicalAlias {
|
|
updatedRoom.Name = (*string)(&content.Alias)
|
|
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
|
|
if content.Alias == "" {
|
|
updatedRoom.NameQuality = database.NameQualityNil
|
|
}
|
|
}
|
|
}
|
|
case event.StateRoomAvatar:
|
|
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
|
|
if ok {
|
|
url, _ := content.URL.Parse()
|
|
updatedRoom.Avatar = &url
|
|
updatedRoom.ExplicitAvatar = true
|
|
}
|
|
case event.StateTopic:
|
|
content, ok := evt.Content.Parsed.(*event.TopicEventContent)
|
|
if ok {
|
|
updatedRoom.Topic = &content.Topic
|
|
}
|
|
}
|
|
return
|
|
}
|