diff --git a/pkg/hicli/database/event.go b/pkg/hicli/database/event.go index 3940614..4934d1e 100644 --- a/pkg/hicli/database/event.go +++ b/pkg/hicli/database/event.go @@ -327,12 +327,17 @@ func (m EventRowID) GetMassInsertValues() [1]any { } type LocalContent struct { - SanitizedHTML string `json:"sanitized_html,omitempty"` - HTMLVersion int `json:"html_version,omitempty"` - WasPlaintext bool `json:"was_plaintext,omitempty"` - BigEmoji bool `json:"big_emoji,omitempty"` - HasMath bool `json:"has_math,omitempty"` - EditSource string `json:"edit_source,omitempty"` + SanitizedHTML string `json:"sanitized_html,omitempty"` + HTMLVersion int `json:"html_version,omitempty"` + WasPlaintext bool `json:"was_plaintext,omitempty"` + BigEmoji bool `json:"big_emoji,omitempty"` + HasMath bool `json:"has_math,omitempty"` + EditSource string `json:"edit_source,omitempty"` + ReplyFallbackRemoved bool `json:"reply_fallback_removed,omitempty"` +} + +func (c *LocalContent) GetReplyFallbackRemoved() bool { + return c != nil && c.ReplyFallbackRemoved } type Event struct { @@ -545,3 +550,10 @@ func (e *Event) BumpsSortingTimestamp() bool { return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) && e.RelationType != event.RelReplace } + +func (e *Event) MarkReplyFallbackRemoved() { + if e.LocalContent == nil { + e.LocalContent = &LocalContent{} + } + e.LocalContent.ReplyFallbackRemoved = true +} diff --git a/pkg/hicli/decryptionqueue.go b/pkg/hicli/decryptionqueue.go index 772f2a7..d5f669a 100644 --- a/pkg/hicli/decryptionqueue.go +++ b/pkg/hicli/decryptionqueue.go @@ -59,7 +59,7 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro } var mautrixEvt *event.Event - mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) + mautrixEvt, err = h.decryptEventInto(ctx, evt.AsRawMautrix(), evt) if err != nil { log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") } else { diff --git a/pkg/hicli/sync.go b/pkg/hicli/sync.go index 2c014c1..ab15a03 100644 --- a/pkg/hicli/sync.go +++ b/pkg/hicli/sync.go @@ -291,20 +291,34 @@ func removeReplyFallback(evt *event.Event) []byte { return nil } -func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) { +func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, bool, string, error) { err := evt.Content.ParseRaw(evt.Type) if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { - return nil, nil, "", err + return nil, nil, false, "", err } decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt) if err != nil { - return nil, nil, "", err + return nil, nil, false, "", err } withoutFallback := removeReplyFallback(decrypted) if withoutFallback != nil { - return decrypted, withoutFallback, decrypted.Type.Type, nil + return decrypted, withoutFallback, true, decrypted.Type.Type, nil } - return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil + return decrypted, decrypted.Content.VeryRaw, false, decrypted.Type.Type, nil +} + +func (h *HiClient) decryptEventInto(ctx context.Context, evt *event.Event, dbEvt *database.Event) (*event.Event, error) { + decryptedEvt, rawContent, fallbackRemoved, decryptedType, err := h.decryptEvent(ctx, evt) + if err != nil { + dbEvt.DecryptionError = err.Error() + return nil, err + } + dbEvt.Decrypted = rawContent + if fallbackRemoved { + dbEvt.MarkReplyFallbackRemoved() + } + dbEvt.DecryptedType = decryptedType + return decryptedEvt, nil } func (h *HiClient) addMediaCache( @@ -445,12 +459,13 @@ func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Ev wasPlaintext = true } return &database.LocalContent{ - SanitizedHTML: sanitizedHTML, - HTMLVersion: CurrentHTMLSanitizerVersion, - WasPlaintext: wasPlaintext, - BigEmoji: bigEmoji, - HasMath: hasMath, - EditSource: editSource, + SanitizedHTML: sanitizedHTML, + HTMLVersion: CurrentHTMLSanitizerVersion, + WasPlaintext: wasPlaintext, + BigEmoji: bigEmoji, + HasMath: hasMath, + EditSource: editSource, + ReplyFallbackRemoved: dbEvt.LocalContent.GetReplyFallbackRemoved(), }, inlineImages } return nil, nil @@ -502,14 +517,12 @@ func (h *HiClient) processEvent( contentWithoutFallback := removeReplyFallback(evt) if contentWithoutFallback != nil { dbEvt.Content = contentWithoutFallback + dbEvt.MarkReplyFallbackRemoved() } var decryptionErr error var decryptedMautrixEvt *event.Event if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" { - decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) - if decryptionErr != nil { - dbEvt.DecryptionError = decryptionErr.Error() - } + decryptedMautrixEvt, decryptionErr = h.decryptEventInto(ctx, evt, dbEvt) } else if evt.Type == event.EventRedaction { if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() { var err error