forked from Mirrors/gomuks
hicli/database: add flag for events that had reply fallbacks removed
This commit is contained in:
parent
29b787f94a
commit
aa8148f5af
3 changed files with 47 additions and 22 deletions
|
@ -327,12 +327,17 @@ func (m EventRowID) GetMassInsertValues() [1]any {
|
|||
}
|
||||
|
||||
type LocalContent struct {
|
||||
SanitizedHTML string `json:"sanitized_html,omitempty"`
|
||||
HTMLVersion int `json:"html_version,omitempty"`
|
||||
WasPlaintext bool `json:"was_plaintext,omitempty"`
|
||||
BigEmoji bool `json:"big_emoji,omitempty"`
|
||||
HasMath bool `json:"has_math,omitempty"`
|
||||
EditSource string `json:"edit_source,omitempty"`
|
||||
SanitizedHTML string `json:"sanitized_html,omitempty"`
|
||||
HTMLVersion int `json:"html_version,omitempty"`
|
||||
WasPlaintext bool `json:"was_plaintext,omitempty"`
|
||||
BigEmoji bool `json:"big_emoji,omitempty"`
|
||||
HasMath bool `json:"has_math,omitempty"`
|
||||
EditSource string `json:"edit_source,omitempty"`
|
||||
ReplyFallbackRemoved bool `json:"reply_fallback_removed,omitempty"`
|
||||
}
|
||||
|
||||
func (c *LocalContent) GetReplyFallbackRemoved() bool {
|
||||
return c != nil && c.ReplyFallbackRemoved
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
|
@ -545,3 +550,10 @@ func (e *Event) BumpsSortingTimestamp() bool {
|
|||
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) &&
|
||||
e.RelationType != event.RelReplace
|
||||
}
|
||||
|
||||
func (e *Event) MarkReplyFallbackRemoved() {
|
||||
if e.LocalContent == nil {
|
||||
e.LocalContent = &LocalContent{}
|
||||
}
|
||||
e.LocalContent.ReplyFallbackRemoved = true
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro
|
|||
}
|
||||
|
||||
var mautrixEvt *event.Event
|
||||
mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
|
||||
mautrixEvt, err = h.decryptEventInto(ctx, evt.AsRawMautrix(), evt)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
|
||||
} else {
|
||||
|
|
|
@ -291,20 +291,34 @@ func removeReplyFallback(evt *event.Event) []byte {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) {
|
||||
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, bool, string, error) {
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
||||
return nil, nil, "", err
|
||||
return nil, nil, false, "", err
|
||||
}
|
||||
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
return nil, nil, false, "", err
|
||||
}
|
||||
withoutFallback := removeReplyFallback(decrypted)
|
||||
if withoutFallback != nil {
|
||||
return decrypted, withoutFallback, decrypted.Type.Type, nil
|
||||
return decrypted, withoutFallback, true, decrypted.Type.Type, nil
|
||||
}
|
||||
return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil
|
||||
return decrypted, decrypted.Content.VeryRaw, false, decrypted.Type.Type, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) decryptEventInto(ctx context.Context, evt *event.Event, dbEvt *database.Event) (*event.Event, error) {
|
||||
decryptedEvt, rawContent, fallbackRemoved, decryptedType, err := h.decryptEvent(ctx, evt)
|
||||
if err != nil {
|
||||
dbEvt.DecryptionError = err.Error()
|
||||
return nil, err
|
||||
}
|
||||
dbEvt.Decrypted = rawContent
|
||||
if fallbackRemoved {
|
||||
dbEvt.MarkReplyFallbackRemoved()
|
||||
}
|
||||
dbEvt.DecryptedType = decryptedType
|
||||
return decryptedEvt, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) addMediaCache(
|
||||
|
@ -445,12 +459,13 @@ func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Ev
|
|||
wasPlaintext = true
|
||||
}
|
||||
return &database.LocalContent{
|
||||
SanitizedHTML: sanitizedHTML,
|
||||
HTMLVersion: CurrentHTMLSanitizerVersion,
|
||||
WasPlaintext: wasPlaintext,
|
||||
BigEmoji: bigEmoji,
|
||||
HasMath: hasMath,
|
||||
EditSource: editSource,
|
||||
SanitizedHTML: sanitizedHTML,
|
||||
HTMLVersion: CurrentHTMLSanitizerVersion,
|
||||
WasPlaintext: wasPlaintext,
|
||||
BigEmoji: bigEmoji,
|
||||
HasMath: hasMath,
|
||||
EditSource: editSource,
|
||||
ReplyFallbackRemoved: dbEvt.LocalContent.GetReplyFallbackRemoved(),
|
||||
}, inlineImages
|
||||
}
|
||||
return nil, nil
|
||||
|
@ -502,14 +517,12 @@ func (h *HiClient) processEvent(
|
|||
contentWithoutFallback := removeReplyFallback(evt)
|
||||
if contentWithoutFallback != nil {
|
||||
dbEvt.Content = contentWithoutFallback
|
||||
dbEvt.MarkReplyFallbackRemoved()
|
||||
}
|
||||
var decryptionErr error
|
||||
var decryptedMautrixEvt *event.Event
|
||||
if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
|
||||
decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
|
||||
if decryptionErr != nil {
|
||||
dbEvt.DecryptionError = decryptionErr.Error()
|
||||
}
|
||||
decryptedMautrixEvt, decryptionErr = h.decryptEventInto(ctx, evt, dbEvt)
|
||||
} else if evt.Type == event.EventRedaction {
|
||||
if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() {
|
||||
var err error
|
||||
|
|
Loading…
Add table
Reference in a new issue