hicli/html: add inline images to media references

This commit is contained in:
Tulir Asokan 2024-10-20 13:31:33 +03:00
parent 2c5738f7f2
commit 535393d47f
3 changed files with 55 additions and 21 deletions

View file

@ -330,7 +330,7 @@ func matrixURIClassName(uri *id.MatrixURI) string {
} }
} }
func writeA(w *strings.Builder, attr []html.Attribute) { func writeA(w *strings.Builder, attr []html.Attribute) (mxc id.ContentURI) {
w.WriteString("<a") w.WriteString("<a")
href := parseAAttributes(attr) href := parseAAttributes(attr)
if href == "" { if href == "" {
@ -365,8 +365,9 @@ func writeA(w *strings.Builder, attr []html.Attribute) {
newTab = false newTab = false
writeAttribute(w, "class", matrixURIClassName(uri)) writeAttribute(w, "class", matrixURIClassName(uri))
case "mxc": case "mxc":
mxc := id.ContentURIString(href).ParseOrIgnore() mxc = id.ContentURIString(href).ParseOrIgnore()
if !mxc.IsValid() { if !mxc.IsValid() {
mxc = id.ContentURI{}
return return
} }
href = fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID) href = fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID)
@ -378,11 +379,12 @@ func writeA(w *strings.Builder, attr []html.Attribute) {
writeAttribute(w, "target", "_blank") writeAttribute(w, "target", "_blank")
writeAttribute(w, "rel", "noreferrer noopener") writeAttribute(w, "rel", "noreferrer noopener")
} }
return
} }
var HTMLSanitizerImgSrcTemplate = "mxc://%s/%s" var HTMLSanitizerImgSrcTemplate = "mxc://%s/%s"
func writeImg(w *strings.Builder, attr []html.Attribute) { func writeImg(w *strings.Builder, attr []html.Attribute) id.ContentURI {
src, alt, title, isCustomEmoji, width, height := parseImgAttributes(attr) src, alt, title, isCustomEmoji, width, height := parseImgAttributes(attr)
w.WriteString("<img") w.WriteString("<img")
writeAttribute(w, "alt", alt) writeAttribute(w, "alt", alt)
@ -391,7 +393,7 @@ func writeImg(w *strings.Builder, attr []html.Attribute) {
} }
mxc := id.ContentURIString(src).ParseOrIgnore() mxc := id.ContentURIString(src).ParseOrIgnore()
if !mxc.IsValid() { if !mxc.IsValid() {
return return id.ContentURI{}
} }
writeAttribute(w, "src", fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID)) writeAttribute(w, "src", fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID))
writeAttribute(w, "loading", "lazy") writeAttribute(w, "loading", "lazy")
@ -403,6 +405,7 @@ func writeImg(w *strings.Builder, attr []html.Attribute) {
} else { } else {
writeAttribute(w, "class", "hicli-sizeless-inline-img") writeAttribute(w, "class", "hicli-sizeless-inline-img")
} }
return mxc
} }
func writeSpan(w *strings.Builder, attr []html.Attribute) { func writeSpan(w *strings.Builder, attr []html.Attribute) {
@ -454,9 +457,10 @@ func (ts *tagStack) pop(tag atom.Atom) bool {
return false return false
} }
func sanitizeAndLinkifyHTML(body string) (string, error) { func sanitizeAndLinkifyHTML(body string) (string, []id.ContentURI, error) {
tz := html.NewTokenizer(strings.NewReader(body)) tz := html.NewTokenizer(strings.NewReader(body))
var built strings.Builder var built strings.Builder
var inlineImages []id.ContentURI
ts := make(tagStack, 2) ts := make(tagStack, 2)
Loop: Loop:
for { for {
@ -466,7 +470,7 @@ Loop:
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
break Loop break Loop
} }
return "", err return "", nil, err
case html.StartTagToken, html.SelfClosingTagToken: case html.StartTagToken, html.SelfClosingTagToken:
token := tz.Token() token := tz.Token()
if !tagIsAllowed(token.DataAtom) { if !tagIsAllowed(token.DataAtom) {
@ -478,9 +482,15 @@ Loop:
} }
switch token.DataAtom { switch token.DataAtom {
case atom.A: case atom.A:
writeA(&built, token.Attr) mxc := writeA(&built, token.Attr)
if !mxc.IsEmpty() {
inlineImages = append(inlineImages, mxc)
}
case atom.Img: case atom.Img:
writeImg(&built, token.Attr) mxc := writeImg(&built, token.Attr)
if !mxc.IsEmpty() {
inlineImages = append(inlineImages, mxc)
}
case atom.Span, atom.Font: case atom.Span, atom.Font:
writeSpan(&built, token.Attr) writeSpan(&built, token.Attr)
default: default:
@ -524,5 +534,5 @@ Loop:
built.WriteString(t.String()) built.WriteString(t.String())
built.WriteByte('>') built.WriteByte('>')
} }
return built.String(), nil return built.String(), inlineImages, nil
} }

View file

@ -137,11 +137,17 @@ func (h *HiClient) Send(
} }
dbEvt.RelatesTo, dbEvt.RelationType = database.GetRelatesToFromBytes(dbEvt.Content) dbEvt.RelatesTo, dbEvt.RelationType = database.GetRelatesToFromBytes(dbEvt.Content)
} }
dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, dbEvt.AsRawMautrix()) var inlineImages []id.ContentURI
mautrixEvt := dbEvt.AsRawMautrix()
dbEvt.LocalContent, inlineImages = h.calculateLocalContent(ctx, dbEvt, mautrixEvt)
_, err = h.DB.Event.Insert(ctx, dbEvt) _, err = h.DB.Event.Insert(ctx, dbEvt)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to insert event into database: %w", err) return nil, fmt.Errorf("failed to insert event into database: %w", err)
} }
h.cacheMedia(ctx, mautrixEvt, dbEvt.RowID)
for _, uri := range inlineImages {
h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
}
ctx = context.WithoutCancel(ctx) ctx = context.WithoutCancel(ctx)
go func() { go func() {
err := h.SetTyping(ctx, room.ID, 0) err := h.SetTyping(ctx, room.ID, 0)

View file

@ -343,14 +343,14 @@ func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID datab
} }
} }
func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent { func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) (*database.LocalContent, []id.ContentURI) {
if evt.Type != event.EventMessage { if evt.Type != event.EventMessage {
return nil return nil, nil
} }
_ = evt.Content.ParseRaw(evt.Type) _ = evt.Content.ParseRaw(evt.Type)
content, ok := evt.Content.Parsed.(*event.MessageEventContent) content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok { if !ok {
return nil return nil, nil
} }
if dbEvt.RelationType == event.RelReplace && content.NewContent != nil { if dbEvt.RelationType == event.RelReplace && content.NewContent != nil {
content = content.NewContent content = content.NewContent
@ -358,8 +358,21 @@ func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Ev
if content != nil { if content != nil {
var sanitizedHTML string var sanitizedHTML string
var wasPlaintext bool var wasPlaintext bool
var inlineImages []id.ContentURI
if content.Format == event.FormatHTML && content.FormattedBody != "" { if content.Format == event.FormatHTML && content.FormattedBody != "" {
sanitizedHTML, _ = sanitizeAndLinkifyHTML(content.FormattedBody) var err error
sanitizedHTML, inlineImages, err = sanitizeAndLinkifyHTML(content.FormattedBody)
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Stringer("event_id", dbEvt.ID).
Msg("Failed to sanitize HTML")
}
if len(inlineImages) > 0 && dbEvt.RowID != 0 {
for _, uri := range inlineImages {
h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
}
inlineImages = nil
}
} else { } else {
var builder strings.Builder var builder strings.Builder
linkifyAndWriteBytes(&builder, []byte(content.Body)) linkifyAndWriteBytes(&builder, []byte(content.Body))
@ -370,9 +383,9 @@ func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Ev
SanitizedHTML: sanitizedHTML, SanitizedHTML: sanitizedHTML,
HTMLVersion: CurrentHTMLSanitizerVersion, HTMLVersion: CurrentHTMLSanitizerVersion,
WasPlaintext: wasPlaintext, WasPlaintext: wasPlaintext,
}, inlineImages
} }
} return nil, nil
return nil
} }
const CurrentHTMLSanitizerVersion = 2 const CurrentHTMLSanitizerVersion = 2
@ -381,7 +394,7 @@ func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Eve
if evt.Type != event.EventMessage.Type || evt.LocalContent == nil || evt.LocalContent.HTMLVersion >= CurrentHTMLSanitizerVersion { if evt.Type != event.EventMessage.Type || evt.LocalContent == nil || evt.LocalContent.HTMLVersion >= CurrentHTMLSanitizerVersion {
return return
} }
evt.LocalContent = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix()) evt.LocalContent, _ = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix())
err := h.DB.Event.UpdateLocalContent(ctx, evt) err := h.DB.Event.UpdateLocalContent(ctx, evt)
if err != nil { if err != nil {
zerolog.Ctx(ctx).Err(err). zerolog.Ctx(ctx).Err(err).
@ -390,14 +403,15 @@ func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Eve
} }
} }
func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) { func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) (inlineImages []id.ContentURI) {
if dbEvt.RowID != 0 { if dbEvt.RowID != 0 {
h.cacheMedia(ctx, evt, dbEvt.RowID) h.cacheMedia(ctx, evt, dbEvt.RowID)
} }
if evt.Sender != h.Account.UserID { if evt.Sender != h.Account.UserID {
dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt) dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt)
} }
dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt) dbEvt.LocalContent, inlineImages = h.calculateLocalContent(ctx, dbEvt, evt)
return
} }
func (h *HiClient) processEvent( func (h *HiClient) processEvent(
@ -438,10 +452,11 @@ func (h *HiClient) processEvent(
evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str) evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str)
} }
} }
var inlineImages []id.ContentURI
if decryptedMautrixEvt != nil { if decryptedMautrixEvt != nil {
h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt) inlineImages = h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt)
} else { } else {
h.postDecryptProcess(ctx, llSummary, dbEvt, evt) inlineImages = h.postDecryptProcess(ctx, llSummary, dbEvt, evt)
} }
_, err := h.DB.Event.Upsert(ctx, dbEvt) _, err := h.DB.Event.Upsert(ctx, dbEvt)
if err != nil { if err != nil {
@ -452,6 +467,9 @@ func (h *HiClient) processEvent(
} else { } else {
h.cacheMedia(ctx, evt, dbEvt.RowID) h.cacheMedia(ctx, evt, dbEvt.RowID)
} }
for _, uri := range inlineImages {
h.addMediaCache(ctx, dbEvt.RowID, uri.CUString(), nil, nil, "")
}
if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
req, ok := decryptionQueue[dbEvt.MegolmSessionID] req, ok := decryptionQueue[dbEvt.MegolmSessionID]
if !ok { if !ok {