hicli/database: add flag for events that had reply fallbacks removed

This commit is contained in:
Tulir Asokan 2024-12-18 00:46:59 +02:00
parent 29b787f94a
commit aa8148f5af
3 changed files with 47 additions and 22 deletions

View file

@ -327,12 +327,17 @@ func (m EventRowID) GetMassInsertValues() [1]any {
}
type LocalContent struct {
SanitizedHTML string `json:"sanitized_html,omitempty"`
HTMLVersion int `json:"html_version,omitempty"`
WasPlaintext bool `json:"was_plaintext,omitempty"`
BigEmoji bool `json:"big_emoji,omitempty"`
HasMath bool `json:"has_math,omitempty"`
EditSource string `json:"edit_source,omitempty"`
SanitizedHTML string `json:"sanitized_html,omitempty"`
HTMLVersion int `json:"html_version,omitempty"`
WasPlaintext bool `json:"was_plaintext,omitempty"`
BigEmoji bool `json:"big_emoji,omitempty"`
HasMath bool `json:"has_math,omitempty"`
EditSource string `json:"edit_source,omitempty"`
ReplyFallbackRemoved bool `json:"reply_fallback_removed,omitempty"`
}
func (c *LocalContent) GetReplyFallbackRemoved() bool {
return c != nil && c.ReplyFallbackRemoved
}
type Event struct {
@ -545,3 +550,10 @@ func (e *Event) BumpsSortingTimestamp() bool {
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) &&
e.RelationType != event.RelReplace
}
func (e *Event) MarkReplyFallbackRemoved() {
if e.LocalContent == nil {
e.LocalContent = &LocalContent{}
}
e.LocalContent.ReplyFallbackRemoved = true
}

View file

@ -59,7 +59,7 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro
}
var mautrixEvt *event.Event
mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
mautrixEvt, err = h.decryptEventInto(ctx, evt.AsRawMautrix(), evt)
if err != nil {
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
} else {

View file

@ -291,20 +291,34 @@ func removeReplyFallback(evt *event.Event) []byte {
return nil
}
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) {
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, bool, string, error) {
err := evt.Content.ParseRaw(evt.Type)
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
return nil, nil, "", err
return nil, nil, false, "", err
}
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
if err != nil {
return nil, nil, "", err
return nil, nil, false, "", err
}
withoutFallback := removeReplyFallback(decrypted)
if withoutFallback != nil {
return decrypted, withoutFallback, decrypted.Type.Type, nil
return decrypted, withoutFallback, true, decrypted.Type.Type, nil
}
return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil
return decrypted, decrypted.Content.VeryRaw, false, decrypted.Type.Type, nil
}
func (h *HiClient) decryptEventInto(ctx context.Context, evt *event.Event, dbEvt *database.Event) (*event.Event, error) {
decryptedEvt, rawContent, fallbackRemoved, decryptedType, err := h.decryptEvent(ctx, evt)
if err != nil {
dbEvt.DecryptionError = err.Error()
return nil, err
}
dbEvt.Decrypted = rawContent
if fallbackRemoved {
dbEvt.MarkReplyFallbackRemoved()
}
dbEvt.DecryptedType = decryptedType
return decryptedEvt, nil
}
func (h *HiClient) addMediaCache(
@ -445,12 +459,13 @@ func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Ev
wasPlaintext = true
}
return &database.LocalContent{
SanitizedHTML: sanitizedHTML,
HTMLVersion: CurrentHTMLSanitizerVersion,
WasPlaintext: wasPlaintext,
BigEmoji: bigEmoji,
HasMath: hasMath,
EditSource: editSource,
SanitizedHTML: sanitizedHTML,
HTMLVersion: CurrentHTMLSanitizerVersion,
WasPlaintext: wasPlaintext,
BigEmoji: bigEmoji,
HasMath: hasMath,
EditSource: editSource,
ReplyFallbackRemoved: dbEvt.LocalContent.GetReplyFallbackRemoved(),
}, inlineImages
}
return nil, nil
@ -502,14 +517,12 @@ func (h *HiClient) processEvent(
contentWithoutFallback := removeReplyFallback(evt)
if contentWithoutFallback != nil {
dbEvt.Content = contentWithoutFallback
dbEvt.MarkReplyFallbackRemoved()
}
var decryptionErr error
var decryptedMautrixEvt *event.Event
if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
if decryptionErr != nil {
dbEvt.DecryptionError = decryptionErr.Error()
}
decryptedMautrixEvt, decryptionErr = h.decryptEventInto(ctx, evt, dbEvt)
} else if evt.Type == event.EventRedaction {
if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() {
var err error