Add proxy for providing authenticated download links

This commit is contained in:
Tulir Asokan 2024-07-12 18:52:23 +03:00
parent 3b2f1c79b9
commit bf922e4b1b
3 changed files with 81 additions and 5 deletions

View file

@ -73,7 +73,7 @@ type MatrixContainer interface {
UploadMedia(path string, encrypt bool) (*UploadedMediaInfo, error) UploadMedia(path string, encrypt bool) (*UploadedMediaInfo, error)
Download(uri id.ContentURI, file *attachment.EncryptedFile) ([]byte, error) Download(uri id.ContentURI, file *attachment.EncryptedFile) ([]byte, error)
DownloadToDisk(uri id.ContentURI, file *attachment.EncryptedFile, target string) (string, error) DownloadToDisk(uri id.ContentURI, file *attachment.EncryptedFile, target string) (string, error)
GetDownloadURL(uri id.ContentURI) string GetDownloadURL(uri id.ContentURI, file *attachment.EncryptedFile) string
GetCachePath(uri id.ContentURI) string GetCachePath(uri id.ContentURI) string
Crypto() Crypto Crypto() Crypto

View file

@ -26,12 +26,16 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest"
"net/url"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"reflect" "reflect"
"runtime" "runtime"
dbg "runtime/debug" dbg "runtime/debug"
"strconv"
"strings"
"time" "time"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
@ -63,6 +67,8 @@ type Container struct {
running bool running bool
stop chan bool stop chan bool
mediaProxyURL string
typing int64 typing int64
} }
@ -172,6 +178,12 @@ func (c *Container) InitClient(isStartup bool) error {
} }
} }
if c.mediaProxyURL == "" {
server := httptest.NewServer(http.HandlerFunc(c.doMediaProxy))
c.mediaProxyURL = server.URL
debug.Print("Started media proxy server at", c.mediaProxyURL)
}
c.stop = make(chan bool, 1) c.stop = make(chan bool, 1)
if len(accessToken) > 0 { if len(accessToken) > 0 {
@ -180,6 +192,61 @@ func (c *Container) InitClient(isStartup bool) error {
return nil return nil
} }
func (c *Container) doMediaProxy(w http.ResponseWriter, r *http.Request) {
parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
if len(parts) != 2 {
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte("Invalid path\n"))
return
}
uri := id.ContentURI{
Homeserver: parts[0],
FileID: parts[1],
}
key := r.URL.Query().Get("k")
iv := r.URL.Query().Get("iv")
hash := r.URL.Query().Get("hash")
var file *attachment.EncryptedFile
if key != "" && iv != "" && hash != "" {
file = &attachment.EncryptedFile{
Key: attachment.JSONWebKey{
Key: key,
Algorithm: "A256CTR",
Extractable: true,
KeyType: "oct",
KeyOps: []string{"encrypt", "decrypt"},
},
InitVector: iv,
Hashes: attachment.EncryptedFileHashes{
SHA256: hash,
},
Version: "v2",
}
}
data, err := c.Download(uri, file)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(fmt.Sprintf("Failed to download media: %v\n", err)))
return
}
mime := http.DetectContentType(data)
w.Header().Add("Content-Length", strconv.Itoa(len(data)))
w.Header().Add("Content-Type", mime)
switch mime {
case "text/css", "text/plain", "text/csv",
"application/json", "application/ld+json",
"image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif",
"video/mp4", "video/webm", "video/ogg", "video/quicktime",
"audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave",
"audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac":
w.Header().Add("Content-Disposition", "inline")
default:
w.Header().Add("Content-Disposition", "attachment")
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}
// Initialized returns whether or not the mautrix client is initialized (see InitClient()) // Initialized returns whether or not the mautrix client is initialized (see InitClient())
func (c *Container) Initialized() bool { func (c *Container) Initialized() bool {
return c.client != nil return c.client != nil
@ -1300,8 +1367,17 @@ func (c *Container) Download(uri id.ContentURI, file *attachment.EncryptedFile)
return return
} }
func (c *Container) GetDownloadURL(uri id.ContentURI) string { func (c *Container) GetDownloadURL(uri id.ContentURI, file *attachment.EncryptedFile) string {
return c.client.GetDownloadURL(uri) addr, _ := url.Parse(c.mediaProxyURL)
addr.Path = path.Join(addr.Path, uri.Homeserver, uri.FileID)
if file != nil {
addr.RawQuery = (&url.Values{
"k": {file.Key.Key},
"iv": {file.InitVector},
"hash": {file.Hashes.SHA256},
}).Encode()
}
return addr.String()
} }
func (c *Container) download(uri id.ContentURI, file *attachment.EncryptedFile, cacheFile string) (data []byte, err error) { func (c *Container) download(uri id.ContentURI, file *attachment.EncryptedFile, cacheFile string) (data []byte, err error) {

View file

@ -106,7 +106,7 @@ func (msg *FileMessage) NotificationContent() string {
} }
func (msg *FileMessage) PlainText() string { func (msg *FileMessage) PlainText() string {
return fmt.Sprintf("%s: %s", msg.Body, msg.matrix.GetDownloadURL(msg.URL)) return fmt.Sprintf("%s: %s", msg.Body, msg.matrix.GetDownloadURL(msg.URL, msg.File))
} }
func (msg *FileMessage) String() string { func (msg *FileMessage) String() string {
@ -146,7 +146,7 @@ func (msg *FileMessage) CalculateBuffer(prefs config.UserPreferences, width int,
} }
if prefs.BareMessageView || prefs.DisableImages || len(msg.imageData) == 0 { if prefs.BareMessageView || prefs.DisableImages || len(msg.imageData) == 0 {
url := msg.matrix.GetDownloadURL(msg.URL) url := msg.matrix.GetDownloadURL(msg.URL, msg.File)
var urlTString tstring.TString var urlTString tstring.TString
if prefs.EnableInlineURLs() { if prefs.EnableInlineURLs() {
urlTString = tstring.NewStyleTString(url, tcell.StyleDefault.Url(url).UrlId(msg.eventID.String())) urlTString = tstring.NewStyleTString(url, tcell.StyleDefault.Url(url).UrlId(msg.eventID.String()))