hicli/database: add flag for events that had reply fallbacks removed

This commit is contained in:
Tulir Asokan 2024-12-18 00:46:59 +02:00
parent 29b787f94a
commit aa8148f5af
3 changed files with 47 additions and 22 deletions

View file

@ -333,6 +333,11 @@ type LocalContent struct {
BigEmoji bool `json:"big_emoji,omitempty"` BigEmoji bool `json:"big_emoji,omitempty"`
HasMath bool `json:"has_math,omitempty"` HasMath bool `json:"has_math,omitempty"`
EditSource string `json:"edit_source,omitempty"` EditSource string `json:"edit_source,omitempty"`
ReplyFallbackRemoved bool `json:"reply_fallback_removed,omitempty"`
}
func (c *LocalContent) GetReplyFallbackRemoved() bool {
return c != nil && c.ReplyFallbackRemoved
} }
type Event struct { type Event struct {
@ -545,3 +550,10 @@ func (e *Event) BumpsSortingTimestamp() bool {
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) && return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) &&
e.RelationType != event.RelReplace e.RelationType != event.RelReplace
} }
func (e *Event) MarkReplyFallbackRemoved() {
if e.LocalContent == nil {
e.LocalContent = &LocalContent{}
}
e.LocalContent.ReplyFallbackRemoved = true
}

View file

@ -59,7 +59,7 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro
} }
var mautrixEvt *event.Event var mautrixEvt *event.Event
mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) mautrixEvt, err = h.decryptEventInto(ctx, evt.AsRawMautrix(), evt)
if err != nil { if err != nil {
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
} else { } else {

View file

@ -291,20 +291,34 @@ func removeReplyFallback(evt *event.Event) []byte {
return nil return nil
} }
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) { func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, bool, string, error) {
err := evt.Content.ParseRaw(evt.Type) err := evt.Content.ParseRaw(evt.Type)
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
return nil, nil, "", err return nil, nil, false, "", err
} }
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt) decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
if err != nil { if err != nil {
return nil, nil, "", err return nil, nil, false, "", err
} }
withoutFallback := removeReplyFallback(decrypted) withoutFallback := removeReplyFallback(decrypted)
if withoutFallback != nil { if withoutFallback != nil {
return decrypted, withoutFallback, decrypted.Type.Type, nil return decrypted, withoutFallback, true, decrypted.Type.Type, nil
} }
return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil return decrypted, decrypted.Content.VeryRaw, false, decrypted.Type.Type, nil
}
func (h *HiClient) decryptEventInto(ctx context.Context, evt *event.Event, dbEvt *database.Event) (*event.Event, error) {
decryptedEvt, rawContent, fallbackRemoved, decryptedType, err := h.decryptEvent(ctx, evt)
if err != nil {
dbEvt.DecryptionError = err.Error()
return nil, err
}
dbEvt.Decrypted = rawContent
if fallbackRemoved {
dbEvt.MarkReplyFallbackRemoved()
}
dbEvt.DecryptedType = decryptedType
return decryptedEvt, nil
} }
func (h *HiClient) addMediaCache( func (h *HiClient) addMediaCache(
@ -451,6 +465,7 @@ func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Ev
BigEmoji: bigEmoji, BigEmoji: bigEmoji,
HasMath: hasMath, HasMath: hasMath,
EditSource: editSource, EditSource: editSource,
ReplyFallbackRemoved: dbEvt.LocalContent.GetReplyFallbackRemoved(),
}, inlineImages }, inlineImages
} }
return nil, nil return nil, nil
@ -502,14 +517,12 @@ func (h *HiClient) processEvent(
contentWithoutFallback := removeReplyFallback(evt) contentWithoutFallback := removeReplyFallback(evt)
if contentWithoutFallback != nil { if contentWithoutFallback != nil {
dbEvt.Content = contentWithoutFallback dbEvt.Content = contentWithoutFallback
dbEvt.MarkReplyFallbackRemoved()
} }
var decryptionErr error var decryptionErr error
var decryptedMautrixEvt *event.Event var decryptedMautrixEvt *event.Event
if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" { if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) decryptedMautrixEvt, decryptionErr = h.decryptEventInto(ctx, evt, dbEvt)
if decryptionErr != nil {
dbEvt.DecryptionError = decryptionErr.Error()
}
} else if evt.Type == event.EventRedaction { } else if evt.Type == event.EventRedaction {
if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() { if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() {
var err error var err error