forked from Mirrors/gomuks
224 lines
6.8 KiB
Go
224 lines
6.8 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
|
|
"github.com/rs/zerolog"
|
|
"go.mau.fi/util/jsontime"
|
|
"go.mau.fi/util/ptr"
|
|
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/hicli/database"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
var ErrBadGateway = mautrix.RespError{
|
|
ErrCode: "FI.MAU.GOMUKS.BAD_GATEWAY",
|
|
StatusCode: http.StatusBadGateway,
|
|
}
|
|
|
|
func (gmx *Gomuks) downloadMediaFromCache(ctx context.Context, w http.ResponseWriter, entry *database.CachedMedia, force bool) bool {
|
|
if !entry.UseCache() {
|
|
if force {
|
|
mautrix.MNotFound.WithMessage("Media not found in cache").Write(w)
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
if entry.Error != nil {
|
|
w.Header().Set("Mau-Cached-Error", "true")
|
|
entry.Error.Write(w)
|
|
return true
|
|
}
|
|
log := zerolog.Ctx(ctx)
|
|
cacheFile, err := os.Open(gmx.cacheEntryToPath(entry))
|
|
if err != nil {
|
|
if errors.Is(err, os.ErrNotExist) && !force {
|
|
return false
|
|
}
|
|
log.Err(err).Msg("Failed to open cache file")
|
|
mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to open cache file: %v", err)).Write(w)
|
|
return true
|
|
}
|
|
defer func() {
|
|
_ = cacheFile.Close()
|
|
}()
|
|
cacheEntryToHeaders(w, entry)
|
|
w.WriteHeader(http.StatusOK)
|
|
_, err = io.Copy(w, cacheFile)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to copy cache file to response")
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (gmx *Gomuks) cacheEntryToPath(entry *database.CachedMedia) string {
|
|
hashPath := hex.EncodeToString(entry.Hash[:])
|
|
return filepath.Join(gmx.CacheDir, "media", hashPath[0:2], hashPath[2:4], hashPath[4:])
|
|
}
|
|
|
|
func cacheEntryToHeaders(w http.ResponseWriter, entry *database.CachedMedia) {
|
|
w.Header().Set("Content-Type", entry.MimeType)
|
|
w.Header().Set("Content-Length", strconv.FormatInt(entry.Size, 10))
|
|
w.Header().Set("Content-Disposition", mime.FormatMediaType(entry.ContentDisposition(), map[string]string{"filename": entry.FileName}))
|
|
w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none';")
|
|
}
|
|
|
|
func (gmx *Gomuks) DownloadMedia(w http.ResponseWriter, r *http.Request) {
|
|
mxc := id.ContentURI{
|
|
Homeserver: r.PathValue("server"),
|
|
FileID: r.PathValue("media_id"),
|
|
}
|
|
if !mxc.IsValid() {
|
|
mautrix.MInvalidParam.WithMessage("Invalid mxc URI").Write(w)
|
|
return
|
|
}
|
|
query := r.URL.Query()
|
|
encrypted, _ := strconv.ParseBool(query.Get("encrypted"))
|
|
|
|
logVal := zerolog.Ctx(r.Context()).With().
|
|
Stringer("mxc_uri", mxc).
|
|
Bool("encrypted", encrypted).
|
|
Logger()
|
|
log := &logVal
|
|
ctx := log.WithContext(r.Context())
|
|
cacheEntry, err := gmx.Client.DB.CachedMedia.Get(ctx, mxc)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to get cached media entry")
|
|
mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to get cached media entry: %v", err)).Write(w)
|
|
return
|
|
} else if (cacheEntry == nil || cacheEntry.EncFile == nil) && encrypted {
|
|
mautrix.MNotFound.WithMessage("Media encryption keys not found in cache").Write(w)
|
|
return
|
|
}
|
|
|
|
if gmx.downloadMediaFromCache(ctx, w, cacheEntry, false) {
|
|
return
|
|
}
|
|
|
|
tempFile, err := os.CreateTemp("", "gomuks-download-*")
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to create temporary file")
|
|
mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to create temp file: %v", err)).Write(w)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = tempFile.Close()
|
|
_ = os.Remove(tempFile.Name())
|
|
}()
|
|
|
|
resp, err := gmx.Client.Client.Download(ctx, mxc)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to download media")
|
|
var httpErr mautrix.HTTPError
|
|
if cacheEntry == nil {
|
|
cacheEntry = &database.CachedMedia{
|
|
MXC: mxc,
|
|
}
|
|
}
|
|
if cacheEntry.Error == nil {
|
|
cacheEntry.Error = &database.MediaError{
|
|
ReceivedAt: jsontime.UnixMilliNow(),
|
|
}
|
|
} else {
|
|
cacheEntry.Error.Attempts++
|
|
cacheEntry.Error.ReceivedAt = jsontime.UnixMilliNow()
|
|
}
|
|
if errors.As(err, &httpErr) {
|
|
if httpErr.WrappedError != nil {
|
|
cacheEntry.Error.Matrix = ptr.Ptr(ErrBadGateway.WithMessage(httpErr.WrappedError.Error()))
|
|
cacheEntry.Error.StatusCode = http.StatusBadGateway
|
|
} else if httpErr.RespError != nil {
|
|
cacheEntry.Error.Matrix = httpErr.RespError
|
|
cacheEntry.Error.StatusCode = httpErr.Response.StatusCode
|
|
} else {
|
|
cacheEntry.Error.Matrix = ptr.Ptr(mautrix.MUnknown.WithMessage("Server returned non-JSON error with status %d", httpErr.Response.StatusCode))
|
|
cacheEntry.Error.StatusCode = httpErr.Response.StatusCode
|
|
}
|
|
} else {
|
|
cacheEntry.Error.Matrix = ptr.Ptr(ErrBadGateway.WithMessage(err.Error()))
|
|
cacheEntry.Error.StatusCode = http.StatusBadGateway
|
|
}
|
|
err = gmx.Client.DB.CachedMedia.Put(ctx, cacheEntry)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to save errored cache entry")
|
|
}
|
|
cacheEntry.Error.Write(w)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = resp.Body.Close()
|
|
}()
|
|
if cacheEntry == nil {
|
|
cacheEntry = &database.CachedMedia{
|
|
MXC: mxc,
|
|
MimeType: resp.Header.Get("Content-Type"),
|
|
Size: resp.ContentLength,
|
|
}
|
|
}
|
|
|
|
reader := resp.Body
|
|
if cacheEntry.EncFile != nil {
|
|
err = cacheEntry.EncFile.PrepareForDecryption()
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to prepare media for decryption")
|
|
mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to prepare media for decryption: %v", err)).Write(w)
|
|
return
|
|
}
|
|
reader = cacheEntry.EncFile.DecryptStream(reader)
|
|
}
|
|
fileHasher := sha256.New()
|
|
hashReader := io.TeeReader(reader, fileHasher)
|
|
cacheEntry.Size, err = io.Copy(tempFile, hashReader)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to copy media to temporary file")
|
|
mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to copy media to temp file: %v", err)).Write(w)
|
|
return
|
|
}
|
|
err = reader.Close()
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to close media reader")
|
|
mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to finish reading media: %v", err)).Write(w)
|
|
return
|
|
}
|
|
_ = tempFile.Close()
|
|
if cacheEntry.FileName == "" {
|
|
_, params, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition"))
|
|
cacheEntry.FileName = params["filename"]
|
|
}
|
|
if cacheEntry.MimeType == "" {
|
|
cacheEntry.MimeType = resp.Header.Get("Content-Type")
|
|
}
|
|
cacheEntry.Hash = (*[32]byte)(fileHasher.Sum(nil))
|
|
cacheEntry.Error = nil
|
|
err = gmx.Client.DB.CachedMedia.Put(ctx, cacheEntry)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to save cache entry")
|
|
mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to save cache entry: %v", err)).Write(w)
|
|
return
|
|
}
|
|
cachePath := gmx.cacheEntryToPath(cacheEntry)
|
|
err = os.MkdirAll(filepath.Dir(cachePath), 0700)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to create cache directory")
|
|
mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to create cache directory: %v", err)).Write(w)
|
|
return
|
|
}
|
|
err = os.Rename(tempFile.Name(), cachePath)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to rename temporary file")
|
|
mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to rename temp file: %v", err)).Write(w)
|
|
return
|
|
}
|
|
gmx.downloadMediaFromCache(ctx, w, cacheEntry, true)
|
|
}
|