forked from Mirrors/gomuks
998 lines
33 KiB
Go
998 lines
33 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package hicli
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
"go.mau.fi/util/emojirunes"
|
|
"go.mau.fi/util/exzerolog"
|
|
"go.mau.fi/util/jsontime"
|
|
"go.mau.fi/util/ptr"
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/crypto"
|
|
"maunium.net/go/mautrix/crypto/olm"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/format"
|
|
"maunium.net/go/mautrix/id"
|
|
"maunium.net/go/mautrix/pushrules"
|
|
|
|
"go.mau.fi/gomuks/pkg/hicli/database"
|
|
)
|
|
|
|
// syncContext carries mutable state for a single sync response. The sync
// processing functions in this file read it back from the context.Context
// under syncContextKey.
type syncContext struct {
	// shouldWakeupRequestQueue is set while processing a sync response when
	// device lists were marked outdated or session requests were queued.
	// postProcessSyncResponse calls WakeupRequestQueue when it is true.
	shouldWakeupRequestQueue bool

	// evt accumulates the data (account data, rooms, left rooms) that
	// postProcessSyncResponse passes to EventHandler if it is not empty.
	evt *SyncComplete
}
|
|
|
|
func (h *HiClient) markSyncErrored(err error) {
|
|
stat := &SyncStatus{
|
|
Type: SyncStatusErrored,
|
|
Error: err.Error(),
|
|
ErrorCount: h.syncErrors,
|
|
LastSync: jsontime.UM(h.lastSync),
|
|
}
|
|
h.SyncStatus.Store(stat)
|
|
h.EventHandler(stat)
|
|
}
|
|
|
|
// Shared SyncStatus values. markSyncOK compares the previously stored status
// against syncOK by pointer, so these must stay package-level singletons.
var (
	// syncOK is the status stored and dispatched when syncing works.
	syncOK = &SyncStatus{Type: SyncStatusOK}
	// syncWaiting is the shared status value for the "waiting" state.
	syncWaiting = &SyncStatus{Type: SyncStatusWaiting}
)
|
|
|
|
func (h *HiClient) markSyncOK() {
|
|
if h.SyncStatus.Swap(syncOK) != syncOK {
|
|
h.EventHandler(syncOK)
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
|
log := zerolog.Ctx(ctx)
|
|
postponedToDevices := resp.ToDevice.Events[:0]
|
|
for _, evt := range resp.ToDevice.Events {
|
|
evt.Type.Class = event.ToDeviceEventType
|
|
err := evt.Content.ParseRaw(evt.Type)
|
|
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
|
log.Warn().Err(err).
|
|
Stringer("event_type", &evt.Type).
|
|
Stringer("sender", evt.Sender).
|
|
Msg("Failed to parse to-device event, skipping")
|
|
continue
|
|
}
|
|
|
|
switch content := evt.Content.Parsed.(type) {
|
|
case *event.EncryptedEventContent:
|
|
h.Crypto.HandleEncryptedEvent(ctx, evt)
|
|
case *event.RoomKeyWithheldEventContent:
|
|
h.Crypto.HandleRoomKeyWithheld(ctx, content)
|
|
default:
|
|
postponedToDevices = append(postponedToDevices, evt)
|
|
}
|
|
}
|
|
resp.ToDevice.Events = postponedToDevices
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
|
|
h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
|
|
go h.asyncPostProcessSyncResponse(ctx, resp, since)
|
|
syncCtx := ctx.Value(syncContextKey).(*syncContext)
|
|
if syncCtx.shouldWakeupRequestQueue {
|
|
h.WakeupRequestQueue()
|
|
}
|
|
if !h.firstSyncReceived {
|
|
h.firstSyncReceived = true
|
|
h.Client.Client.Transport.(*http.Transport).ResponseHeaderTimeout = 60 * time.Second
|
|
h.Client.Client.Timeout = 180 * time.Second
|
|
}
|
|
if !syncCtx.evt.IsEmpty() {
|
|
h.EventHandler(syncCtx.evt)
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
|
|
for _, evt := range resp.ToDevice.Events {
|
|
switch content := evt.Content.Parsed.(type) {
|
|
case *event.SecretRequestEventContent:
|
|
h.Crypto.HandleSecretRequest(ctx, evt.Sender, content)
|
|
case *event.RoomKeyRequestEventContent:
|
|
h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
|
if len(resp.DeviceLists.Changed) > 0 {
|
|
zerolog.Ctx(ctx).Debug().
|
|
Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)).
|
|
Msg("Marking changed device lists for tracked users as outdated")
|
|
err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to mark changed device lists as outdated: %w", err)
|
|
}
|
|
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
|
|
}
|
|
|
|
accountData := make(map[event.Type]*database.AccountData, len(resp.AccountData.Events))
|
|
var err error
|
|
for _, evt := range resp.AccountData.Events {
|
|
evt.Type.Class = event.AccountDataEventType
|
|
accountData[evt.Type], err = h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
|
|
}
|
|
if evt.Type == event.AccountDataPushRules {
|
|
err = evt.Content.ParseRaw(evt.Type)
|
|
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
|
zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to parse push rules in sync")
|
|
} else if pushRules, ok := evt.Content.Parsed.(*pushrules.EventContent); ok {
|
|
h.receiveNewPushRules(ctx, pushRules.Ruleset)
|
|
zerolog.Ctx(ctx).Debug().Msg("Updated push rules from sync")
|
|
}
|
|
}
|
|
}
|
|
ctx.Value(syncContextKey).(*syncContext).evt.AccountData = accountData
|
|
for roomID, room := range resp.Rooms.Join {
|
|
err := h.processSyncJoinedRoom(ctx, roomID, room)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process joined room %s: %w", roomID, err)
|
|
}
|
|
}
|
|
for roomID, room := range resp.Rooms.Leave {
|
|
err := h.processSyncLeftRoom(ctx, roomID, room)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process left room %s: %w", roomID, err)
|
|
}
|
|
}
|
|
h.Account.NextBatch = resp.NextBatch
|
|
err = h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save next_batch: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) {
|
|
receiptList := make([]*database.Receipt, 0)
|
|
var newOwnReceipts []id.EventID
|
|
for eventID, receipts := range *content {
|
|
for receiptType, users := range receipts {
|
|
for userID, receiptInfo := range users {
|
|
if userID == h.Account.UserID {
|
|
newOwnReceipts = append(newOwnReceipts, eventID)
|
|
}
|
|
receiptList = append(receiptList, &database.Receipt{
|
|
UserID: userID,
|
|
ReceiptType: receiptType,
|
|
ThreadID: receiptInfo.ThreadID,
|
|
EventID: eventID,
|
|
Timestamp: jsontime.UM(receiptInfo.Timestamp),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
return receiptList, newOwnReceipts
|
|
}
|
|
|
|
func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
|
|
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get room data: %w", err)
|
|
} else if existingRoomData == nil {
|
|
err = h.DB.Room.CreateRow(ctx, roomID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to ensure room row exists: %w", err)
|
|
}
|
|
existingRoomData = &database.Room{
|
|
ID: roomID,
|
|
// Hack to set a default value for SortingTimestamp which is before all existing rooms,
|
|
// but not the same for all rooms without a timestamp.
|
|
SortingTimestamp: jsontime.UM(time.UnixMilli(time.Now().Unix())),
|
|
}
|
|
}
|
|
|
|
accountData := make(map[event.Type]*database.AccountData, len(room.AccountData.Events))
|
|
for _, evt := range room.AccountData.Events {
|
|
evt.Type.Class = event.AccountDataEventType
|
|
evt.RoomID = roomID
|
|
accountData[evt.Type], err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
|
|
}
|
|
}
|
|
var receiptsList []*database.Receipt
|
|
var newOwnReceipts []id.EventID
|
|
for _, evt := range room.Ephemeral.Events {
|
|
evt.Type.Class = event.EphemeralEventType
|
|
err = evt.Content.ParseRaw(evt.Type)
|
|
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
|
zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content")
|
|
continue
|
|
}
|
|
switch evt.Type {
|
|
case event.EphemeralEventReceipt:
|
|
list, ownList := h.receiptsToList(evt.Content.AsReceipt())
|
|
receiptsList = append(receiptsList, list...)
|
|
newOwnReceipts = append(newOwnReceipts, ownList...)
|
|
case event.EphemeralEventTyping:
|
|
go h.EventHandler(&Typing{
|
|
RoomID: roomID,
|
|
TypingEventContent: *evt.Content.AsTyping(),
|
|
})
|
|
}
|
|
}
|
|
err = h.processStateAndTimeline(
|
|
ctx,
|
|
existingRoomData,
|
|
&room.State,
|
|
&room.Timeline,
|
|
&room.Summary,
|
|
receiptsList,
|
|
newOwnReceipts,
|
|
accountData,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error {
|
|
zerolog.Ctx(ctx).Debug().Stringer("room_id", roomID).Msg("Deleting left room")
|
|
err := h.DB.Room.Delete(ctx, roomID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete room: %w", err)
|
|
}
|
|
payload := ctx.Value(syncContextKey).(*syncContext).evt
|
|
payload.LeftRooms = append(payload.LeftRooms, roomID)
|
|
return nil
|
|
}
|
|
|
|
func isDecryptionErrorRetryable(err error) bool {
|
|
return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld)
|
|
}
|
|
|
|
func removeReplyFallback(evt *event.Event) []byte {
|
|
if evt.Type != event.EventMessage && evt.Type != event.EventSticker {
|
|
return nil
|
|
}
|
|
_ = evt.Content.ParseRaw(evt.Type)
|
|
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
|
if ok && content.RelatesTo.GetReplyTo() != "" {
|
|
prevFormattedBody := content.FormattedBody
|
|
content.RemoveReplyFallback()
|
|
if content.FormattedBody != prevFormattedBody {
|
|
bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody)
|
|
bytes, err2 := sjson.SetBytes(bytes, "body", content.Body)
|
|
if err == nil && err2 == nil {
|
|
return bytes
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) {
|
|
err := evt.Content.ParseRaw(evt.Type)
|
|
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
|
return nil, nil, "", err
|
|
}
|
|
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
|
|
if err != nil {
|
|
return nil, nil, "", err
|
|
}
|
|
withoutFallback := removeReplyFallback(decrypted)
|
|
if withoutFallback != nil {
|
|
return decrypted, withoutFallback, decrypted.Type.Type, nil
|
|
}
|
|
return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil
|
|
}
|
|
|
|
func (h *HiClient) addMediaCache(
|
|
ctx context.Context,
|
|
eventRowID database.EventRowID,
|
|
uri id.ContentURIString,
|
|
file *event.EncryptedFileInfo,
|
|
info *event.FileInfo,
|
|
fileName string,
|
|
) {
|
|
parsedMXC := uri.ParseOrIgnore()
|
|
if !parsedMXC.IsValid() {
|
|
return
|
|
}
|
|
cm := &database.Media{
|
|
MXC: parsedMXC,
|
|
FileName: fileName,
|
|
}
|
|
if file != nil {
|
|
cm.EncFile = &file.EncryptedFile
|
|
}
|
|
if info != nil {
|
|
cm.MimeType = info.MimeType
|
|
}
|
|
err := h.DB.Media.Put(ctx, cm)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Warn().Err(err).
|
|
Stringer("mxc", parsedMXC).
|
|
Int64("event_rowid", int64(eventRowID)).
|
|
Msg("Failed to add database media entry")
|
|
return
|
|
}
|
|
err = h.DB.Media.AddReference(ctx, eventRowID, parsedMXC)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Warn().Err(err).
|
|
Stringer("mxc", parsedMXC).
|
|
Int64("event_rowid", int64(eventRowID)).
|
|
Msg("Failed to add database media reference")
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) {
|
|
switch evt.Type {
|
|
case event.EventMessage, event.EventSticker:
|
|
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
|
if !ok {
|
|
return
|
|
}
|
|
if content.File != nil {
|
|
h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName())
|
|
} else if content.URL != "" {
|
|
h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName())
|
|
}
|
|
if content.GetInfo().ThumbnailFile != nil {
|
|
h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "")
|
|
} else if content.GetInfo().ThumbnailURL != "" {
|
|
h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "")
|
|
}
|
|
case event.StateRoomAvatar:
|
|
_ = evt.Content.ParseRaw(evt.Type)
|
|
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
|
|
if !ok {
|
|
return
|
|
}
|
|
h.addMediaCache(ctx, rowID, content.URL, nil, nil, "")
|
|
case event.StateMember:
|
|
_ = evt.Content.ParseRaw(evt.Type)
|
|
content, ok := evt.Content.Parsed.(*event.MemberEventContent)
|
|
if !ok {
|
|
return
|
|
}
|
|
h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "")
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) (*database.LocalContent, []id.ContentURI) {
|
|
if evt.Type != event.EventMessage {
|
|
return nil, nil
|
|
}
|
|
_ = evt.Content.ParseRaw(evt.Type)
|
|
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
|
if !ok {
|
|
return nil, nil
|
|
}
|
|
if dbEvt.RelationType == event.RelReplace && content.NewContent != nil {
|
|
content = content.NewContent
|
|
}
|
|
if content != nil {
|
|
var sanitizedHTML, editSource string
|
|
var wasPlaintext, hasMath, bigEmoji bool
|
|
var inlineImages []id.ContentURI
|
|
if content.Format == event.FormatHTML && content.FormattedBody != "" {
|
|
var err error
|
|
sanitizedHTML, inlineImages, err = sanitizeAndLinkifyHTML(content.FormattedBody)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Warn().Err(err).
|
|
Stringer("event_id", dbEvt.ID).
|
|
Msg("Failed to sanitize HTML")
|
|
}
|
|
hasMath = strings.Contains(sanitizedHTML, "<hicli-math")
|
|
if len(inlineImages) > 0 && dbEvt.RowID != 0 {
|
|
for _, uri := range inlineImages {
|
|
h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
|
|
}
|
|
inlineImages = nil
|
|
}
|
|
if dbEvt.LocalContent != nil && dbEvt.LocalContent.EditSource != "" {
|
|
editSource = dbEvt.LocalContent.EditSource
|
|
} else if evt.Sender == h.Account.UserID {
|
|
editSource, _ = format.HTMLToMarkdownFull(htmlToMarkdownForInput, content.FormattedBody)
|
|
if content.MsgType == event.MsgEmote {
|
|
editSource = "/me " + editSource
|
|
} else if content.MsgType == event.MsgNotice {
|
|
editSource = "/notice " + editSource
|
|
}
|
|
}
|
|
} else {
|
|
hasSpecialCharacters := false
|
|
for _, char := range content.Body {
|
|
if char == '<' || char == '>' || char == '&' || char == '.' || char == ':' {
|
|
hasSpecialCharacters = true
|
|
break
|
|
}
|
|
}
|
|
if hasSpecialCharacters {
|
|
var builder strings.Builder
|
|
builder.Grow(len(content.Body) + builderPreallocBuffer)
|
|
linkifyAndWriteBytes(&builder, []byte(content.Body))
|
|
sanitizedHTML = builder.String()
|
|
} else if len(content.Body) < 100 && emojirunes.IsOnlyEmojis(content.Body) {
|
|
bigEmoji = true
|
|
}
|
|
if content.MsgType == event.MsgEmote {
|
|
editSource = "/me " + content.Body
|
|
} else if content.MsgType == event.MsgNotice {
|
|
editSource = "/notice " + content.Body
|
|
}
|
|
wasPlaintext = true
|
|
}
|
|
return &database.LocalContent{
|
|
SanitizedHTML: sanitizedHTML,
|
|
HTMLVersion: CurrentHTMLSanitizerVersion,
|
|
WasPlaintext: wasPlaintext,
|
|
BigEmoji: bigEmoji,
|
|
HasMath: hasMath,
|
|
EditSource: editSource,
|
|
}, inlineImages
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// CurrentHTMLSanitizerVersion is the version number stored in
// LocalContent.HTMLVersion by calculateLocalContent. ReprocessExistingEvent
// recalculates the local content of events whose stored version is lower.
const CurrentHTMLSanitizerVersion = 7
|
|
|
|
func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Event) {
|
|
if (evt.Type != event.EventMessage.Type && evt.DecryptedType != event.EventMessage.Type) ||
|
|
evt.LocalContent == nil || evt.LocalContent.HTMLVersion >= CurrentHTMLSanitizerVersion {
|
|
return
|
|
}
|
|
evt.LocalContent, _ = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix())
|
|
err := h.DB.Event.UpdateLocalContent(ctx, evt)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Err(err).
|
|
Stringer("event_id", evt.ID).
|
|
Msg("Failed to update local content")
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) (inlineImages []id.ContentURI) {
|
|
if dbEvt.RowID != 0 {
|
|
h.cacheMedia(ctx, evt, dbEvt.RowID)
|
|
}
|
|
if evt.Sender != h.Account.UserID {
|
|
dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt)
|
|
}
|
|
dbEvt.LocalContent, inlineImages = h.calculateLocalContent(ctx, dbEvt, evt)
|
|
return
|
|
}
|
|
|
|
func (h *HiClient) processEvent(
|
|
ctx context.Context,
|
|
evt *event.Event,
|
|
llSummary *mautrix.LazyLoadSummary,
|
|
decryptionQueue map[id.SessionID]*database.SessionRequest,
|
|
checkDB bool,
|
|
) (*database.Event, error) {
|
|
if checkDB {
|
|
dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to check if event %s exists: %w", evt.ID, err)
|
|
} else if dbEvt != nil {
|
|
return dbEvt, nil
|
|
}
|
|
}
|
|
dbEvt := database.MautrixToEvent(evt)
|
|
contentWithoutFallback := removeReplyFallback(evt)
|
|
if contentWithoutFallback != nil {
|
|
dbEvt.Content = contentWithoutFallback
|
|
}
|
|
var decryptionErr error
|
|
var decryptedMautrixEvt *event.Event
|
|
if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
|
|
decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
|
|
if decryptionErr != nil {
|
|
dbEvt.DecryptionError = decryptionErr.Error()
|
|
}
|
|
} else if evt.Type == event.EventRedaction {
|
|
if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() {
|
|
var err error
|
|
evt.Content.VeryRaw, err = sjson.SetBytes(evt.Content.VeryRaw, "redacts", evt.Redacts)
|
|
if err != nil {
|
|
return dbEvt, fmt.Errorf("failed to set redacts field: %w", err)
|
|
}
|
|
} else if evt.Redacts == "" {
|
|
evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str)
|
|
}
|
|
}
|
|
var inlineImages []id.ContentURI
|
|
if decryptedMautrixEvt != nil {
|
|
inlineImages = h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt)
|
|
} else {
|
|
inlineImages = h.postDecryptProcess(ctx, llSummary, dbEvt, evt)
|
|
}
|
|
_, err := h.DB.Event.Upsert(ctx, dbEvt)
|
|
if err != nil {
|
|
return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
|
|
}
|
|
if decryptedMautrixEvt != nil {
|
|
h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID)
|
|
} else {
|
|
h.cacheMedia(ctx, evt, dbEvt.RowID)
|
|
}
|
|
for _, uri := range inlineImages {
|
|
h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
|
|
}
|
|
if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
|
|
req, ok := decryptionQueue[dbEvt.MegolmSessionID]
|
|
if !ok {
|
|
req = &database.SessionRequest{
|
|
RoomID: evt.RoomID,
|
|
SessionID: dbEvt.MegolmSessionID,
|
|
Sender: evt.Sender,
|
|
}
|
|
}
|
|
minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext)
|
|
req.MinIndex = min(uint32(minIndex), req.MinIndex)
|
|
if decryptionQueue != nil {
|
|
decryptionQueue[dbEvt.MegolmSessionID] = req
|
|
} else {
|
|
err = h.DB.SessionRequest.Put(ctx, req)
|
|
if err != nil {
|
|
zerolog.Ctx(ctx).Err(err).
|
|
Stringer("session_id", dbEvt.MegolmSessionID).
|
|
Msg("Failed to save session request")
|
|
} else {
|
|
h.WakeupRequestQueue()
|
|
}
|
|
}
|
|
}
|
|
return dbEvt, err
|
|
}
|
|
|
|
// unsetSortingTimestamp is the cutoff below which a room's sorting timestamp
// is treated as not yet set: processStateAndTimeline replaces any earlier
// timestamp with the timestamp of the newest timeline event. The value is
// 10^12 milliseconds after the Unix epoch.
var unsetSortingTimestamp = time.UnixMilli(1000000000000)
|
|
|
|
// processStateAndTimeline stores the state and timeline events of one room
// and updates the room's metadata accordingly.
//
// room is the current database state of the room; changes are collected in a
// separate updatedRoom value and merged back with CheckChangesAndCopyInto.
// state and timeline are the new events, summary is the lazy-load summary
// from the server (replaced by the stored one when it carries no data),
// receipts and newOwnReceipts are the receipts to save and the event IDs the
// own user has receipts for, and accountData is the room account data that
// was already saved by the caller.
//
// When anything changed, a SyncRoom entry for the room is added to the
// pending SyncComplete event in the sync context, so ctx must carry a
// *syncContext under syncContextKey.
func (h *HiClient) processStateAndTimeline(
	ctx context.Context,
	room *database.Room,
	state *mautrix.SyncEventsList,
	timeline *mautrix.SyncTimeline,
	summary *mautrix.LazyLoadSummary,
	receipts []*database.Receipt,
	newOwnReceipts []id.EventID,
	accountData map[event.Type]*database.AccountData,
) error {
	// Only the fields that may be modified incrementally below are copied;
	// everything else starts at its zero value.
	updatedRoom := &database.Room{
		ID: room.ID,

		SortingTimestamp: room.SortingTimestamp,
		NameQuality:      room.NameQuality,
		UnreadCounts:     room.UnreadCounts,
	}
	heroesChanged := false
	// An entirely empty summary means the server sent no update, so fall back
	// to the stored one. Otherwise record the new summary if it differs.
	if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil {
		summary = room.LazyLoadSummary
	} else if room.LazyLoadSummary == nil ||
		!slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) ||
		!intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) ||
		!intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) {
		updatedRoom.LazyLoadSummary = summary
		heroesChanged = true
	}
	// Session requests for undecryptable events, collected per session and
	// saved in one go after the timeline has been processed.
	decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
	// Every event that should be sent to the frontend: new events plus old
	// events affected by them (redaction/relation targets, new preview event).
	allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events))
	newNotifications := make([]SyncNotification, 0)
	var recalculatePreviewEvent, unreadMessagesWereMaybeRedacted bool
	var newUnreadCounts database.UnreadCounts
	// addOldEvent loads an existing event by row ID (if non-zero) or event ID
	// and appends it to allNewEvents. A missing event is not an error: it
	// returns (nil, nil).
	addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) {
		if rowID != 0 {
			dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID)
		} else {
			dbEvt, err = h.DB.Event.GetByID(ctx, evtID)
		}
		if err != nil {
			return nil, fmt.Errorf("failed to get redaction target: %w", err)
		} else if dbEvt == nil {
			return nil, nil
		}
		allNewEvents = append(allNewEvents, dbEvt)
		return dbEvt, nil
	}
	// processRedaction pulls in the redacted event (and, for edits and
	// reactions, the event it relates to) and flags the unread counts and the
	// preview event for recalculation when they may be affected.
	processRedaction := func(evt *event.Event) error {
		dbEvt, err := addOldEvent(0, evt.Redacts)
		if err != nil {
			return fmt.Errorf("failed to get redaction target: %w", err)
		}
		if dbEvt == nil {
			return nil
		}
		if dbEvt.UnreadType > 0 {
			unreadMessagesWereMaybeRedacted = true
		}
		if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
			_, err = addOldEvent(0, dbEvt.RelatesTo)
			if err != nil {
				return fmt.Errorf("failed to get relation target of redaction target: %w", err)
			}
		}
		if updatedRoom.PreviewEventRowID == dbEvt.RowID {
			updatedRoom.PreviewEventRowID = 0
			recalculatePreviewEvent = true
		}
		return nil
	}
	// processNewEvent saves a single new event and applies its effects on the
	// room: notifications and unread counts (isUnread), preview event and
	// sorting timestamp (isTimeline), current state (state events), and
	// redaction/relation targets. It returns the event's row ID.
	processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) {
		evt.RoomID = room.ID
		// Events with a transaction ID are checked against the database first.
		dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, evt.Unsigned.TransactionID != "")
		if err != nil {
			return -1, err
		}
		if isUnread {
			// Notifications are only collected once the first sync has been received.
			if dbEvt.UnreadType.Is(database.UnreadTypeNotify) && h.firstSyncReceived {
				newNotifications = append(newNotifications, SyncNotification{
					RowID: dbEvt.RowID,
					Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound),
				})
			}
			newUnreadCounts.AddOne(dbEvt.UnreadType)
		}
		if isTimeline {
			if dbEvt.CanUseForPreview() {
				updatedRoom.PreviewEventRowID = dbEvt.RowID
				recalculatePreviewEvent = false
			}
			updatedRoom.BumpSortingTimestamp(dbEvt)
		}
		if evt.StateKey != nil {
			var membership event.Membership
			if evt.Type == event.StateMember {
				membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str)
				// A member event of a hero may change the participant-based name.
				if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) {
					heroesChanged = true
				}
			} else if evt.Type == event.StateElementFunctionalMembers {
				heroesChanged = true
			}
			err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership)
			if err != nil {
				return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err)
			}
			processImportantEvent(ctx, evt, room, updatedRoom)
		}
		allNewEvents = append(allNewEvents, dbEvt)
		if evt.Type == event.EventRedaction && evt.Redacts != "" {
			err = processRedaction(evt)
			if err != nil {
				return -1, fmt.Errorf("failed to process redaction: %w", err)
			}
		} else if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
			_, err = addOldEvent(0, dbEvt.RelatesTo)
			if err != nil {
				return -1, fmt.Errorf("failed to get relation target of event: %w", err)
			}
		}
		return dbEvt.RowID, nil
	}
	// changedState maps event type -> state key -> row ID of the new state event.
	changedState := make(map[event.Type]map[string]database.EventRowID)
	setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) {
		if _, ok := changedState[evtType]; !ok {
			changedState[evtType] = make(map[string]database.EventRowID)
		}
		changedState[evtType][stateKey] = rowID
	}
	for _, evt := range state.Events {
		evt.Type.Class = event.StateEventType
		rowID, err := processNewEvent(evt, false, false)
		if err != nil {
			return err
		}
		// NOTE(review): StateKey is dereferenced without a nil check here;
		// this relies on every event in the state section having a state key.
		setNewState(evt.Type, *evt.StateKey, rowID)
	}
	var timelineRowTuples []database.TimelineRowTuple
	var err error
	if len(timeline.Events) > 0 {
		timelineIDs := make([]database.EventRowID, len(timeline.Events))
		// Find the newest timeline event that is either covered by one of our
		// own new receipts or was sent by us; everything after it is unread.
		readUpToIndex := -1
		for i := len(timeline.Events) - 1; i >= 0; i-- {
			evt := timeline.Events[i]
			isRead := slices.Contains(newOwnReceipts, evt.ID)
			isOwnEvent := evt.Sender == h.Account.UserID
			if isRead || isOwnEvent {
				readUpToIndex = i
				// Reset unread counts if we see our own read receipt in the timeline.
				// It'll be updated with new unreads (if any) at the end.
				updatedRoom.UnreadCounts = database.UnreadCounts{}
				if !isRead {
					// Our own event without an explicit receipt: add a
					// read receipt for it.
					receipts = append(receipts, &database.Receipt{
						RoomID:      room.ID,
						UserID:      h.Account.UserID,
						ReceiptType: event.ReceiptTypeRead,
						EventID:     evt.ID,
						Timestamp:   jsontime.UM(time.UnixMilli(evt.Timestamp)),
					})
					newOwnReceipts = append(newOwnReceipts, evt.ID)
				}
				break
			}
		}
		for i, evt := range timeline.Events {
			if evt.StateKey != nil {
				evt.Type.Class = event.StateEventType
			} else {
				evt.Type.Class = event.MessageEventType
			}
			timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex)
			if err != nil {
				return err
			}
			if evt.StateKey != nil {
				setNewState(evt.Type, *evt.StateKey, timelineIDs[i])
			}
		}
		// Rooms whose sorting timestamp is still a placeholder get the
		// timestamp of the newest timeline event.
		if updatedRoom.SortingTimestamp.Before(unsetSortingTimestamp) && len(timeline.Events) > 0 {
			updatedRoom.SortingTimestamp = jsontime.UM(time.UnixMilli(timeline.Events[len(timeline.Events)-1].Timestamp))
		}
		for _, entry := range decryptionQueue {
			err = h.DB.SessionRequest.Put(ctx, entry)
			if err != nil {
				return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
			}
		}
		if len(decryptionQueue) > 0 {
			ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
		}
		if timeline.Limited {
			// A limited timeline has a gap before it: drop the stored
			// timeline and interrupt any pagination in progress for the room.
			err = h.DB.Timeline.Clear(ctx, room.ID)
			if err != nil {
				return fmt.Errorf("failed to clear old timeline: %w", err)
			}
			updatedRoom.PrevBatch = timeline.PrevBatch
			h.paginationInterrupterLock.Lock()
			if interrupt, ok := h.paginationInterrupter[room.ID]; ok {
				interrupt(ErrTimelineReset)
			}
			h.paginationInterrupterLock.Unlock()
		}
		timelineRowTuples, err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs)
		if err != nil {
			return fmt.Errorf("failed to append timeline: %w", err)
		}
	} else {
		// Keep the slice non-nil even when there are no timeline events.
		timelineRowTuples = make([]database.TimelineRowTuple, 0)
	}
	// The preview event was redacted and no later timeline event replaced it.
	if recalculatePreviewEvent && updatedRoom.PreviewEventRowID == 0 {
		updatedRoom.PreviewEventRowID, err = h.DB.Room.RecalculatePreview(ctx, room.ID)
		if err != nil {
			return fmt.Errorf("failed to recalculate preview event: %w", err)
		}
		_, err = addOldEvent(updatedRoom.PreviewEventRowID, "")
		if err != nil {
			return fmt.Errorf("failed to get preview event: %w", err)
		}
	}
	// Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset
	if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil {
		name, dmAvatarURL, err := h.calculateRoomParticipantName(ctx, room.ID, summary)
		if err != nil {
			return fmt.Errorf("failed to calculate room name: %w", err)
		}
		updatedRoom.Name = &name
		updatedRoom.NameQuality = database.NameQualityParticipants
		// Only use the participant's avatar when the room has no explicit one.
		if !dmAvatarURL.IsEmpty() && !room.ExplicitAvatar {
			updatedRoom.Avatar = &dmAvatarURL
		}
	}
	mu, ok := accountData[event.AccountDataMarkedUnread]
	if ok {
		updatedRoom.MarkedUnread = ptr.Ptr(gjson.GetBytes(mu.Content, "unread").Bool())
	}

	if len(receipts) > 0 {
		err = h.DB.Receipt.PutMany(ctx, room.ID, receipts...)
		if err != nil {
			return fmt.Errorf("failed to save receipts: %w", err)
		}
	}
	// Recalculate unread counts from the database when the room had unreads
	// and either a new own receipt arrived without new unreads, or an unread
	// message may have been redacted. Otherwise just add the new unreads.
	if !room.UnreadCounts.IsZero() && ((len(newOwnReceipts) > 0 && newUnreadCounts.IsZero()) || unreadMessagesWereMaybeRedacted) {
		updatedRoom.UnreadCounts, err = h.DB.Room.CalculateUnreads(ctx, room.ID, h.Account.UserID)
		if err != nil {
			return fmt.Errorf("failed to recalculate unread counts: %w", err)
		}
	} else {
		updatedRoom.UnreadCounts.Add(newUnreadCounts)
	}
	if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) {
		updatedRoom.PrevBatch = timeline.PrevBatch
	}
	roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
	if roomChanged {
		err = h.DB.Room.Upsert(ctx, updatedRoom)
		if err != nil {
			return fmt.Errorf("failed to save room data: %w", err)
		}
	}
	// TODO why is *old* unread count sometimes zero when processing the read receipt that is making it zero?
	if roomChanged || len(accountData) > 0 || len(newOwnReceipts) > 0 || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 {
		ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
			Meta:          room,
			Timeline:      timelineRowTuples,
			AccountData:   accountData,
			State:         changedState,
			Reset:         timeline.Limited,
			Events:        allNewEvents,
			Notifications: newNotifications,
		}
	}
	return nil
}
|
|
|
|
// joinMemberNames formats a list of member display names for use as a room
// name.
//
//   - No names yield an empty string.
//   - A single name is returned as-is.
//   - Up to four names, or exactly five when totalCount is at most 6, are
//     joined as "a, b, c and d".
//   - Otherwise the first four names are followed by "and N others", where N
//     is totalCount-5.
//
// totalCount is the total member count of the room as passed by the caller.
func joinMemberNames(names []string, totalCount int) string {
	switch {
	case len(names) == 0:
		// Guard against slicing with a negative bound below.
		return ""
	case len(names) == 1:
		return names[0]
	case len(names) < 5 || (len(names) == 5 && totalCount <= 6):
		last := len(names) - 1
		return strings.Join(names[:last], ", ") + " and " + names[last]
	default:
		return fmt.Sprintf("%s and %d others", strings.Join(names[:4], ", "), totalCount-5)
	}
}
|
|
|
|
func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, id.ContentURI, error) {
|
|
var primaryAvatarURL id.ContentURI
|
|
if summary == nil || len(summary.Heroes) == 0 {
|
|
return "Empty room", primaryAvatarURL, nil
|
|
}
|
|
var functionalMembers []id.UserID
|
|
functionalMembersEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "")
|
|
if err != nil {
|
|
return "", primaryAvatarURL, fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err)
|
|
} else if functionalMembersEvt != nil {
|
|
mautrixEvt := functionalMembersEvt.AsRawMautrix()
|
|
_ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type)
|
|
content, ok := mautrixEvt.Content.Parsed.(*event.ElementFunctionalMembersContent)
|
|
if ok {
|
|
functionalMembers = content.ServiceMembers
|
|
}
|
|
}
|
|
var members, leftMembers []string
|
|
var memberCount int
|
|
if summary.JoinedMemberCount != nil && *summary.JoinedMemberCount > 0 {
|
|
memberCount = *summary.JoinedMemberCount
|
|
} else if summary.InvitedMemberCount != nil {
|
|
memberCount = *summary.InvitedMemberCount
|
|
}
|
|
for _, hero := range summary.Heroes {
|
|
if slices.Contains(functionalMembers, hero) {
|
|
memberCount--
|
|
continue
|
|
} else if len(members) >= 5 {
|
|
break
|
|
}
|
|
heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String())
|
|
if err != nil {
|
|
return "", primaryAvatarURL, fmt.Errorf("failed to get %s's member event: %w", hero, err)
|
|
} else if heroEvt == nil {
|
|
leftMembers = append(leftMembers, hero.String())
|
|
continue
|
|
}
|
|
membership := gjson.GetBytes(heroEvt.Content, "membership").Str
|
|
name := gjson.GetBytes(heroEvt.Content, "displayname").Str
|
|
if name == "" {
|
|
name = hero.String()
|
|
}
|
|
avatarURL := gjson.GetBytes(heroEvt.Content, "avatar_url").Str
|
|
if avatarURL != "" {
|
|
primaryAvatarURL = id.ContentURIString(avatarURL).ParseOrIgnore()
|
|
}
|
|
if membership == "join" || membership == "invite" {
|
|
members = append(members, name)
|
|
} else {
|
|
leftMembers = append(leftMembers, name)
|
|
}
|
|
}
|
|
if len(members)+len(leftMembers) > 1 || !primaryAvatarURL.IsValid() {
|
|
primaryAvatarURL = id.ContentURI{}
|
|
}
|
|
if len(members) > 0 {
|
|
return joinMemberNames(members, memberCount), primaryAvatarURL, nil
|
|
} else if len(leftMembers) > 0 {
|
|
return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), primaryAvatarURL, nil
|
|
} else {
|
|
return "Empty room", primaryAvatarURL, nil
|
|
}
|
|
}
|
|
|
|
// intPtrEqual reports whether two optional ints are equal: either both
// pointers are nil, or both are non-nil and point at the same value.
func intPtrEqual(a, b *int) bool {
	if a != nil && b != nil {
		return *a == *b
	}
	// At least one side is nil, so they only match when both are.
	return a == nil && b == nil
}
|
|
|
|
func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData, updatedRoom *database.Room) (roomDataChanged bool) {
|
|
if evt.StateKey == nil {
|
|
return
|
|
}
|
|
switch evt.Type {
|
|
case event.StateCreate, event.StateTombstone, event.StateRoomName, event.StateCanonicalAlias,
|
|
event.StateRoomAvatar, event.StateTopic, event.StateEncryption:
|
|
if *evt.StateKey != "" {
|
|
return
|
|
}
|
|
default:
|
|
return
|
|
}
|
|
err := evt.Content.ParseRaw(evt.Type)
|
|
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
|
zerolog.Ctx(ctx).Warn().Err(err).
|
|
Stringer("event_type", &evt.Type).
|
|
Stringer("event_id", evt.ID).
|
|
Msg("Failed to parse state event, skipping")
|
|
return
|
|
}
|
|
switch evt.Type {
|
|
case event.StateCreate:
|
|
updatedRoom.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent)
|
|
case event.StateTombstone:
|
|
updatedRoom.Tombstone, _ = evt.Content.Parsed.(*event.TombstoneEventContent)
|
|
case event.StateEncryption:
|
|
newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent)
|
|
if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm {
|
|
updatedRoom.EncryptionEvent = newEncryption
|
|
}
|
|
case event.StateRoomName:
|
|
content, ok := evt.Content.Parsed.(*event.RoomNameEventContent)
|
|
if ok {
|
|
updatedRoom.Name = &content.Name
|
|
updatedRoom.NameQuality = database.NameQualityExplicit
|
|
if content.Name == "" {
|
|
if updatedRoom.CanonicalAlias != nil && *updatedRoom.CanonicalAlias != "" {
|
|
updatedRoom.Name = (*string)(updatedRoom.CanonicalAlias)
|
|
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
|
|
} else if existingRoomData.CanonicalAlias != nil && *existingRoomData.CanonicalAlias != "" {
|
|
updatedRoom.Name = (*string)(existingRoomData.CanonicalAlias)
|
|
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
|
|
} else {
|
|
updatedRoom.NameQuality = database.NameQualityNil
|
|
}
|
|
}
|
|
}
|
|
case event.StateCanonicalAlias:
|
|
content, ok := evt.Content.Parsed.(*event.CanonicalAliasEventContent)
|
|
if ok {
|
|
updatedRoom.CanonicalAlias = &content.Alias
|
|
if updatedRoom.NameQuality <= database.NameQualityCanonicalAlias {
|
|
updatedRoom.Name = (*string)(&content.Alias)
|
|
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
|
|
if content.Alias == "" {
|
|
updatedRoom.NameQuality = database.NameQualityNil
|
|
}
|
|
}
|
|
}
|
|
case event.StateRoomAvatar:
|
|
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
|
|
if ok {
|
|
url, _ := content.URL.Parse()
|
|
updatedRoom.Avatar = &url
|
|
updatedRoom.ExplicitAvatar = true
|
|
}
|
|
case event.StateTopic:
|
|
content, ok := evt.Content.Parsed.(*event.TopicEventContent)
|
|
if ok {
|
|
updatedRoom.Topic = &content.Topic
|
|
}
|
|
}
|
|
return
|
|
}
|