From bf922e4b1baf87bdd97d32119d2778bdf83aefbf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Jul 2024 18:52:23 +0300 Subject: [PATCH] Add proxy for providing authenticated download links --- interface/matrix.go | 2 +- matrix/matrix.go | 80 +++++++++++++++++++++++++++++++++++++- ui/messages/filemessage.go | 4 +- 3 files changed, 81 insertions(+), 5 deletions(-) diff --git a/interface/matrix.go b/interface/matrix.go index d4b2baa..58c428f 100644 --- a/interface/matrix.go +++ b/interface/matrix.go @@ -73,7 +73,7 @@ type MatrixContainer interface { UploadMedia(path string, encrypt bool) (*UploadedMediaInfo, error) Download(uri id.ContentURI, file *attachment.EncryptedFile) ([]byte, error) DownloadToDisk(uri id.ContentURI, file *attachment.EncryptedFile, target string) (string, error) - GetDownloadURL(uri id.ContentURI) string + GetDownloadURL(uri id.ContentURI, file *attachment.EncryptedFile) string GetCachePath(uri id.ContentURI) string Crypto() Crypto diff --git a/matrix/matrix.go b/matrix/matrix.go index 3a9affe..8ecf145 100644 --- a/matrix/matrix.go +++ b/matrix/matrix.go @@ -26,12 +26,16 @@ import ( "io" "io/ioutil" "net/http" + "net/http/httptest" + "net/url" "os" "path" "path/filepath" "reflect" "runtime" dbg "runtime/debug" + "strconv" + "strings" "time" "maunium.net/go/mautrix" @@ -63,6 +67,8 @@ type Container struct { running bool stop chan bool + mediaProxyURL string + typing int64 } @@ -172,6 +178,12 @@ func (c *Container) InitClient(isStartup bool) error { } } + if c.mediaProxyURL == "" { + server := httptest.NewServer(http.HandlerFunc(c.doMediaProxy)) + c.mediaProxyURL = server.URL + debug.Print("Started media proxy server at", c.mediaProxyURL) + } + c.stop = make(chan bool, 1) if len(accessToken) > 0 { @@ -180,6 +192,61 @@ func (c *Container) InitClient(isStartup bool) error { return nil } +func (c *Container) doMediaProxy(w http.ResponseWriter, r *http.Request) { + parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") + if len(parts) != 2 { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("Invalid path\n")) + return + } + uri := id.ContentURI{ + Homeserver: parts[0], + FileID: parts[1], + } + key := r.URL.Query().Get("k") + iv := r.URL.Query().Get("iv") + hash := r.URL.Query().Get("hash") + var file *attachment.EncryptedFile + if key != "" && iv != "" && hash != "" { + file = &attachment.EncryptedFile{ + Key: attachment.JSONWebKey{ + Key: key, + Algorithm: "A256CTR", + Extractable: true, + KeyType: "oct", + KeyOps: []string{"encrypt", "decrypt"}, + }, + InitVector: iv, + Hashes: attachment.EncryptedFileHashes{ + SHA256: hash, + }, + Version: "v2", + } + } + data, err := c.Download(uri, file) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(fmt.Sprintf("Failed to download media: %v\n", err))) + return + } + mime := http.DetectContentType(data) + w.Header().Add("Content-Length", strconv.Itoa(len(data))) + w.Header().Add("Content-Type", mime) + switch mime { + case "text/css", "text/plain", "text/csv", + "application/json", "application/ld+json", + "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif", + "video/mp4", "video/webm", "video/ogg", "video/quicktime", + "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", + "audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac": + w.Header().Add("Content-Disposition", "inline") + default: + w.Header().Add("Content-Disposition", "attachment") + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write(data) +} + // Initialized returns whether or not the mautrix client is initialized (see InitClient()) func (c *Container) Initialized() bool { return c.client != nil @@ -1300,8 +1367,17 @@ func (c *Container) Download(uri id.ContentURI, file *attachment.EncryptedFile) return } -func (c *Container) GetDownloadURL(uri id.ContentURI) string { - return c.client.GetDownloadURL(uri) +func (c *Container) GetDownloadURL(uri id.ContentURI, file *attachment.EncryptedFile) string { + addr, _ := url.Parse(c.mediaProxyURL) + addr.Path = path.Join(addr.Path, uri.Homeserver, uri.FileID) + if file != nil { + addr.RawQuery = (&url.Values{ + "k": {file.Key.Key}, + "iv": {file.InitVector}, + "hash": {file.Hashes.SHA256}, + }).Encode() + } + return addr.String() } func (c *Container) download(uri id.ContentURI, file *attachment.EncryptedFile, cacheFile string) (data []byte, err error) { diff --git a/ui/messages/filemessage.go b/ui/messages/filemessage.go index da5f50e..abc0b52 100644 --- a/ui/messages/filemessage.go +++ b/ui/messages/filemessage.go @@ -106,7 +106,7 @@ func (msg *FileMessage) NotificationContent() string { } func (msg *FileMessage) PlainText() string { - return fmt.Sprintf("%s: %s", msg.Body, msg.matrix.GetDownloadURL(msg.URL)) + return fmt.Sprintf("%s: %s", msg.Body, msg.matrix.GetDownloadURL(msg.URL, msg.File)) } func (msg *FileMessage) String() string { @@ -146,7 +146,7 @@ func (msg *FileMessage) CalculateBuffer(prefs config.UserPreferences, width int, } if prefs.BareMessageView || prefs.DisableImages || len(msg.imageData) == 0 { - url := msg.matrix.GetDownloadURL(msg.URL) + url := msg.matrix.GetDownloadURL(msg.URL, msg.File) var urlTString tstring.TString if prefs.EnableInlineURLs() { urlTString = tstring.NewStyleTString(url, tcell.StyleDefault.Url(url).UrlId(msg.eventID.String()))