diff --git a/pkg/hicli/html.go b/pkg/hicli/html.go
index af92864..bcc9fbb 100644
--- a/pkg/hicli/html.go
+++ b/pkg/hicli/html.go
@@ -330,7 +330,7 @@ func matrixURIClassName(uri *id.MatrixURI) string {
}
}
-func writeA(w *strings.Builder, attr []html.Attribute) {
+func writeA(w *strings.Builder, attr []html.Attribute) (mxc id.ContentURI) {
w.WriteString("')
}
- return built.String(), nil
+ return built.String(), inlineImages, nil
}
diff --git a/pkg/hicli/send.go b/pkg/hicli/send.go
index 42a8a1e..058601e 100644
--- a/pkg/hicli/send.go
+++ b/pkg/hicli/send.go
@@ -137,11 +137,17 @@ func (h *HiClient) Send(
}
dbEvt.RelatesTo, dbEvt.RelationType = database.GetRelatesToFromBytes(dbEvt.Content)
}
- dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, dbEvt.AsRawMautrix())
+ var inlineImages []id.ContentURI
+ mautrixEvt := dbEvt.AsRawMautrix()
+ dbEvt.LocalContent, inlineImages = h.calculateLocalContent(ctx, dbEvt, mautrixEvt)
_, err = h.DB.Event.Insert(ctx, dbEvt)
if err != nil {
return nil, fmt.Errorf("failed to insert event into database: %w", err)
}
+ h.cacheMedia(ctx, mautrixEvt, dbEvt.RowID)
+ for _, uri := range inlineImages {
+ h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
+ }
ctx = context.WithoutCancel(ctx)
go func() {
err := h.SetTyping(ctx, room.ID, 0)
diff --git a/pkg/hicli/sync.go b/pkg/hicli/sync.go
index d0fbda2..36c59b3 100644
--- a/pkg/hicli/sync.go
+++ b/pkg/hicli/sync.go
@@ -343,14 +343,14 @@ func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID datab
}
}
-func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent {
+func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) (*database.LocalContent, []id.ContentURI) {
if evt.Type != event.EventMessage {
- return nil
+ return nil, nil
}
_ = evt.Content.ParseRaw(evt.Type)
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok {
- return nil
+ return nil, nil
}
if dbEvt.RelationType == event.RelReplace && content.NewContent != nil {
content = content.NewContent
@@ -358,8 +358,21 @@ func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Ev
if content != nil {
var sanitizedHTML string
var wasPlaintext bool
+ var inlineImages []id.ContentURI
if content.Format == event.FormatHTML && content.FormattedBody != "" {
- sanitizedHTML, _ = sanitizeAndLinkifyHTML(content.FormattedBody)
+ var err error
+ sanitizedHTML, inlineImages, err = sanitizeAndLinkifyHTML(content.FormattedBody)
+ if err != nil {
+ zerolog.Ctx(ctx).Warn().Err(err).
+ Stringer("event_id", dbEvt.ID).
+ Msg("Failed to sanitize HTML")
+ }
+ if len(inlineImages) > 0 && dbEvt.RowID != 0 {
+ for _, uri := range inlineImages {
+ h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
+ }
+ inlineImages = nil
+ }
} else {
var builder strings.Builder
linkifyAndWriteBytes(&builder, []byte(content.Body))
@@ -370,9 +383,9 @@ func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Ev
SanitizedHTML: sanitizedHTML,
HTMLVersion: CurrentHTMLSanitizerVersion,
WasPlaintext: wasPlaintext,
- }
+ }, inlineImages
}
- return nil
+ return nil, nil
}
const CurrentHTMLSanitizerVersion = 2
@@ -381,7 +394,7 @@ func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Eve
if evt.Type != event.EventMessage.Type || evt.LocalContent == nil || evt.LocalContent.HTMLVersion >= CurrentHTMLSanitizerVersion {
return
}
- evt.LocalContent = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix())
+ evt.LocalContent, _ = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix())
err := h.DB.Event.UpdateLocalContent(ctx, evt)
if err != nil {
zerolog.Ctx(ctx).Err(err).
@@ -390,14 +403,15 @@ func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Eve
}
}
-func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) {
+func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) (inlineImages []id.ContentURI) {
if dbEvt.RowID != 0 {
h.cacheMedia(ctx, evt, dbEvt.RowID)
}
if evt.Sender != h.Account.UserID {
dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt)
}
- dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt)
+ dbEvt.LocalContent, inlineImages = h.calculateLocalContent(ctx, dbEvt, evt)
+ return
}
func (h *HiClient) processEvent(
@@ -438,10 +452,11 @@ func (h *HiClient) processEvent(
evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str)
}
}
+ var inlineImages []id.ContentURI
if decryptedMautrixEvt != nil {
- h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt)
+ inlineImages = h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt)
} else {
- h.postDecryptProcess(ctx, llSummary, dbEvt, evt)
+ inlineImages = h.postDecryptProcess(ctx, llSummary, dbEvt, evt)
}
_, err := h.DB.Event.Upsert(ctx, dbEvt)
if err != nil {
@@ -452,6 +467,9 @@ func (h *HiClient) processEvent(
} else {
h.cacheMedia(ctx, evt, dbEvt.RowID)
}
+ for _, uri := range inlineImages {
+ h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
+ }
if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
req, ok := decryptionQueue[dbEvt.MegolmSessionID]
if !ok {