diff --git a/pkg/hicli/html.go b/pkg/hicli/html.go index af92864..bcc9fbb 100644 --- a/pkg/hicli/html.go +++ b/pkg/hicli/html.go @@ -330,7 +330,7 @@ func matrixURIClassName(uri *id.MatrixURI) string { } } -func writeA(w *strings.Builder, attr []html.Attribute) { +func writeA(w *strings.Builder, attr []html.Attribute) (mxc id.ContentURI) { w.WriteString("') } - return built.String(), nil + return built.String(), inlineImages, nil } diff --git a/pkg/hicli/send.go b/pkg/hicli/send.go index 42a8a1e..058601e 100644 --- a/pkg/hicli/send.go +++ b/pkg/hicli/send.go @@ -137,11 +137,17 @@ func (h *HiClient) Send( } dbEvt.RelatesTo, dbEvt.RelationType = database.GetRelatesToFromBytes(dbEvt.Content) } - dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, dbEvt.AsRawMautrix()) + var inlineImages []id.ContentURI + mautrixEvt := dbEvt.AsRawMautrix() + dbEvt.LocalContent, inlineImages = h.calculateLocalContent(ctx, dbEvt, mautrixEvt) _, err = h.DB.Event.Insert(ctx, dbEvt) if err != nil { return nil, fmt.Errorf("failed to insert event into database: %w", err) } + h.cacheMedia(ctx, mautrixEvt, dbEvt.RowID) + for _, uri := range inlineImages { + h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "") + } ctx = context.WithoutCancel(ctx) go func() { err := h.SetTyping(ctx, room.ID, 0) diff --git a/pkg/hicli/sync.go b/pkg/hicli/sync.go index d0fbda2..36c59b3 100644 --- a/pkg/hicli/sync.go +++ b/pkg/hicli/sync.go @@ -343,14 +343,14 @@ func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID datab } } -func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent { +func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) (*database.LocalContent, []id.ContentURI) { if evt.Type != event.EventMessage { - return nil + return nil, nil } _ = evt.Content.ParseRaw(evt.Type) content, ok := evt.Content.Parsed.(*event.MessageEventContent) if !ok { - return nil + return nil, nil } if dbEvt.RelationType == event.RelReplace && content.NewContent != nil { content = content.NewContent @@ -358,8 +358,21 @@ func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Ev if content != nil { var sanitizedHTML string var wasPlaintext bool + var inlineImages []id.ContentURI if content.Format == event.FormatHTML && content.FormattedBody != "" { - sanitizedHTML, _ = sanitizeAndLinkifyHTML(content.FormattedBody) + var err error + sanitizedHTML, inlineImages, err = sanitizeAndLinkifyHTML(content.FormattedBody) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("event_id", dbEvt.ID). + Msg("Failed to sanitize HTML") + } + if len(inlineImages) > 0 && dbEvt.RowID != 0 { + for _, uri := range inlineImages { + h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "") + } + inlineImages = nil + } } else { var builder strings.Builder linkifyAndWriteBytes(&builder, []byte(content.Body)) @@ -370,9 +383,9 @@ func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Ev SanitizedHTML: sanitizedHTML, HTMLVersion: CurrentHTMLSanitizerVersion, WasPlaintext: wasPlaintext, - } + }, inlineImages } - return nil + return nil, nil } const CurrentHTMLSanitizerVersion = 2 @@ -381,7 +394,7 @@ func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Eve if evt.Type != event.EventMessage.Type || evt.LocalContent == nil || evt.LocalContent.HTMLVersion >= CurrentHTMLSanitizerVersion { return } - evt.LocalContent = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix()) + evt.LocalContent, _ = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix()) err := h.DB.Event.UpdateLocalContent(ctx, evt) if err != nil { zerolog.Ctx(ctx).Err(err). @@ -390,14 +403,15 @@ func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Eve } } -func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) { +func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) (inlineImages []id.ContentURI) { if dbEvt.RowID != 0 { h.cacheMedia(ctx, evt, dbEvt.RowID) } if evt.Sender != h.Account.UserID { dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt) } - dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt) + dbEvt.LocalContent, inlineImages = h.calculateLocalContent(ctx, dbEvt, evt) + return } func (h *HiClient) processEvent( @@ -438,10 +452,11 @@ func (h *HiClient) processEvent( evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str) } } + var inlineImages []id.ContentURI if decryptedMautrixEvt != nil { - h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt) + inlineImages = h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt) } else { - h.postDecryptProcess(ctx, llSummary, dbEvt, evt) + inlineImages = h.postDecryptProcess(ctx, llSummary, dbEvt, evt) } _, err := h.DB.Event.Upsert(ctx, dbEvt) if err != nil { @@ -452,6 +467,9 @@ func (h *HiClient) processEvent( } else { h.cacheMedia(ctx, evt, dbEvt.RowID) } + for _, uri := range inlineImages { + h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "") + } if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { req, ok := decryptionQueue[dbEvt.MegolmSessionID] if !ok {