1
0
Fork 0
forked from Mirrors/gomuks
nyxmuks/media.go
2024-10-08 23:47:59 +03:00

224 lines
6.8 KiB
Go

package main
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strconv"
"github.com/rs/zerolog"
"go.mau.fi/util/jsontime"
"go.mau.fi/util/ptr"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/hicli/database"
"maunium.net/go/mautrix/id"
)
// ErrBadGateway is the error template (custom error code with HTTP 502) that
// DownloadMedia uses when a media download fails without a usable error
// response from the server; callers attach the details via WithMessage.
var ErrBadGateway = mautrix.RespError{
ErrCode: "FI.MAU.GOMUKS.BAD_GATEWAY",
StatusCode: http.StatusBadGateway,
}
// downloadMediaFromCache tries to answer the request from the local media
// cache.
//
// It returns true when a response has been written to w (the cached file, a
// cached error, or a failure message) and false when nothing was written and
// the caller should fetch the media itself.
//
// With force set, the function never returns false: an unusable entry is
// answered with mautrix.MNotFound, and a missing cache file is reported as an
// error instead of being treated as a cache miss.
//
// NOTE(review): DownloadMedia may pass a nil entry, so entry.UseCache() is
// presumably safe to call on a nil receiver — confirm in the database package.
func (gmx *Gomuks) downloadMediaFromCache(ctx context.Context, w http.ResponseWriter, entry *database.CachedMedia, force bool) bool {
	switch {
	case !entry.UseCache():
		if !force {
			return false
		}
		mautrix.MNotFound.WithMessage("Media not found in cache").Write(w)
		return true
	case entry.Error != nil:
		// Replay the stored failure and flag it so the client can tell it
		// was not produced by a fresh download attempt.
		w.Header().Set("Mau-Cached-Error", "true")
		entry.Error.Write(w)
		return true
	}

	logger := zerolog.Ctx(ctx)
	file, openErr := os.Open(gmx.cacheEntryToPath(entry))
	if openErr != nil {
		// A missing file is only a cache miss when the caller is able to
		// fall back to downloading.
		if !force && errors.Is(openErr, os.ErrNotExist) {
			return false
		}
		logger.Err(openErr).Msg("Failed to open cache file")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to open cache file: %v", openErr)).Write(w)
		return true
	}
	defer file.Close()

	cacheEntryToHeaders(w, entry)
	w.WriteHeader(http.StatusOK)
	// The status line is already sent, so a copy failure can only be logged.
	if _, copyErr := io.Copy(w, file); copyErr != nil {
		logger.Err(copyErr).Msg("Failed to copy cache file to response")
	}
	return true
}
// cacheEntryToPath returns the on-disk location of a cached media file.
//
// The path is derived from the hex encoding of entry.Hash: the first two hex
// characters and the next two each become a directory level under
// <CacheDir>/media, and the remaining characters are the file name.
// entry.Hash must be non-nil.
func (gmx *Gomuks) cacheEntryToPath(entry *database.CachedMedia) string {
	hexHash := hex.EncodeToString(entry.Hash[:])
	firstLevel, secondLevel, fileName := hexHash[:2], hexHash[2:4], hexHash[4:]
	return filepath.Join(gmx.CacheDir, "media", firstLevel, secondLevel, fileName)
}
// cacheEntryToHeaders sets the response headers describing a cached media
// file: its MIME type, its size, a Content-Disposition carrying the stored
// file name, and a restrictive Content-Security-Policy.
func cacheEntryToHeaders(w http.ResponseWriter, entry *database.CachedMedia) {
	disposition := mime.FormatMediaType(
		entry.ContentDisposition(),
		map[string]string{"filename": entry.FileName},
	)

	headers := w.Header()
	headers.Set("Content-Type", entry.MimeType)
	headers.Set("Content-Length", strconv.FormatInt(entry.Size, 10))
	headers.Set("Content-Disposition", disposition)
	// Sandbox the response and forbid all sources so served media cannot run
	// scripts or load further resources.
	headers.Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none';")
}
// DownloadMedia is the HTTP handler that serves the media behind an mxc URI.
//
// The URI is built from the "server" and "media_id" path values. The
// "encrypted" query parameter states whether the caller expects encrypted
// media; in that case the cache entry must already carry the encryption keys
// (EncFile), otherwise mautrix.MNotFound is returned.
//
// If the cache entry is usable the response comes straight from the cache.
// Otherwise the media is downloaded (and decrypted when EncFile is set) into a
// temporary file while its SHA-256 hash is computed, the cache entry is saved,
// the file is moved to its hash-derived cache path, and the response is served
// from there.
//
// A failed download is stored on the cache entry as an error (incrementing the
// attempt counter of an existing error) and that error is sent to the client.
// Failures caused by the request itself being cancelled are not stored.
func (gmx *Gomuks) DownloadMedia(w http.ResponseWriter, r *http.Request) {
	mxc := id.ContentURI{
		Homeserver: r.PathValue("server"),
		FileID:     r.PathValue("media_id"),
	}
	if !mxc.IsValid() {
		mautrix.MInvalidParam.WithMessage("Invalid mxc URI").Write(w)
		return
	}
	query := r.URL.Query()
	// A missing or unparseable "encrypted" parameter counts as false.
	encrypted, _ := strconv.ParseBool(query.Get("encrypted"))

	logVal := zerolog.Ctx(r.Context()).With().
		Stringer("mxc_uri", mxc).
		Bool("encrypted", encrypted).
		Logger()
	log := &logVal
	ctx := log.WithContext(r.Context())

	cacheEntry, err := gmx.Client.DB.CachedMedia.Get(ctx, mxc)
	if err != nil {
		log.Err(err).Msg("Failed to get cached media entry")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to get cached media entry: %v", err)).Write(w)
		return
	}
	if (cacheEntry == nil || cacheEntry.EncFile == nil) && encrypted {
		mautrix.MNotFound.WithMessage("Media encryption keys not found in cache").Write(w)
		return
	}

	if gmx.downloadMediaFromCache(ctx, w, cacheEntry, false) {
		return
	}

	// Keep the temporary file inside the cache directory rather than the
	// system temp directory: the file is later moved into place with
	// os.Rename, which fails when source and destination are on different
	// filesystems (e.g. a tmpfs /tmp).
	tempDir := filepath.Join(gmx.CacheDir, "media")
	if err = os.MkdirAll(tempDir, 0700); err != nil {
		log.Err(err).Msg("Failed to create cache directory")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to create cache directory: %v", err)).Write(w)
		return
	}
	tempFile, err := os.CreateTemp(tempDir, "gomuks-download-*")
	if err != nil {
		log.Err(err).Msg("Failed to create temporary file")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to create temp file: %v", err)).Write(w)
		return
	}
	defer func() {
		// Best-effort cleanup for every exit path. After a successful rename
		// the file no longer exists under this name and Remove fails
		// harmlessly.
		_ = tempFile.Close()
		_ = os.Remove(tempFile.Name())
	}()

	resp, err := gmx.Client.Client.Download(ctx, mxc)
	if err != nil {
		if ctx.Err() != nil {
			// The request was cancelled (typically the client went away).
			// That says nothing about the media itself, so don't record it
			// as a download failure. 499 is the conventional non-standard
			// "client closed request" status.
			log.Debug().Err(err).Msg("Media download aborted because the request context is done")
			w.WriteHeader(499)
			return
		}
		log.Err(err).Msg("Failed to download media")
		if cacheEntry == nil {
			cacheEntry = &database.CachedMedia{
				MXC: mxc,
			}
		}
		if cacheEntry.Error == nil {
			cacheEntry.Error = &database.MediaError{
				ReceivedAt: jsontime.UnixMilliNow(),
			}
		} else {
			cacheEntry.Error.Attempts++
			cacheEntry.Error.ReceivedAt = jsontime.UnixMilliNow()
		}

		var httpErr mautrix.HTTPError
		isHTTPErr := errors.As(err, &httpErr)
		switch {
		case httpErr.WrappedError != nil:
			// The request itself failed (no usable server response).
			cacheEntry.Error.Matrix = ptr.Ptr(ErrBadGateway.WithMessage(httpErr.WrappedError.Error()))
			cacheEntry.Error.StatusCode = http.StatusBadGateway
		case !isHTTPErr || httpErr.Response == nil:
			// Not an HTTP error, or one without a response to take a status
			// code from: report it as a gateway failure instead of
			// dereferencing a nil response.
			cacheEntry.Error.Matrix = ptr.Ptr(ErrBadGateway.WithMessage(err.Error()))
			cacheEntry.Error.StatusCode = http.StatusBadGateway
		case httpErr.RespError != nil:
			// The server returned a structured Matrix error; pass it on.
			cacheEntry.Error.Matrix = httpErr.RespError
			cacheEntry.Error.StatusCode = httpErr.Response.StatusCode
		default:
			cacheEntry.Error.Matrix = ptr.Ptr(mautrix.MUnknown.WithMessage("Server returned non-JSON error with status %d", httpErr.Response.StatusCode))
			cacheEntry.Error.StatusCode = httpErr.Response.StatusCode
		}

		err = gmx.Client.DB.CachedMedia.Put(ctx, cacheEntry)
		if err != nil {
			// The client still gets the error even if it could not be saved.
			log.Err(err).Msg("Failed to save errored cache entry")
		}
		cacheEntry.Error.Write(w)
		return
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	if cacheEntry == nil {
		cacheEntry = &database.CachedMedia{
			MXC:      mxc,
			MimeType: resp.Header.Get("Content-Type"),
			Size:     resp.ContentLength,
		}
	}

	reader := resp.Body
	if cacheEntry.EncFile != nil {
		err = cacheEntry.EncFile.PrepareForDecryption()
		if err != nil {
			log.Err(err).Msg("Failed to prepare media for decryption")
			mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to prepare media for decryption: %v", err)).Write(w)
			return
		}
		reader = cacheEntry.EncFile.DecryptStream(reader)
	}

	// Hash the (decrypted) bytes while writing them out; the hash determines
	// the final cache path.
	fileHasher := sha256.New()
	hashReader := io.TeeReader(reader, fileHasher)
	cacheEntry.Size, err = io.Copy(tempFile, hashReader)
	if err != nil {
		log.Err(err).Msg("Failed to copy media to temporary file")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to copy media to temp file: %v", err)).Write(w)
		return
	}
	// Close the reader explicitly and check the result before trusting the
	// data. NOTE(review): presumably the decrypting reader reports integrity
	// failures on Close — confirm in the attachment package.
	err = reader.Close()
	if err != nil {
		log.Err(err).Msg("Failed to close media reader")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to finish reading media: %v", err)).Write(w)
		return
	}
	// A failed close can mean the data never fully reached the disk, so the
	// file must not be cached in that case.
	err = tempFile.Close()
	if err != nil {
		log.Err(err).Msg("Failed to close temporary file")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to close temp file: %v", err)).Write(w)
		return
	}

	// Fill in metadata from the response only where the entry has none.
	if cacheEntry.FileName == "" {
		_, params, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition"))
		cacheEntry.FileName = params["filename"]
	}
	if cacheEntry.MimeType == "" {
		cacheEntry.MimeType = resp.Header.Get("Content-Type")
	}
	cacheEntry.Hash = (*[32]byte)(fileHasher.Sum(nil))
	cacheEntry.Error = nil

	err = gmx.Client.DB.CachedMedia.Put(ctx, cacheEntry)
	if err != nil {
		log.Err(err).Msg("Failed to save cache entry")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to save cache entry: %v", err)).Write(w)
		return
	}

	cachePath := gmx.cacheEntryToPath(cacheEntry)
	err = os.MkdirAll(filepath.Dir(cachePath), 0700)
	if err != nil {
		log.Err(err).Msg("Failed to create cache directory")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to create cache directory: %v", err)).Write(w)
		return
	}
	err = os.Rename(tempFile.Name(), cachePath)
	if err != nil {
		log.Err(err).Msg("Failed to rename temporary file")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to rename temp file: %v", err)).Write(w)
		return
	}

	// The file is in place now; force ensures a response is always written.
	gmx.downloadMediaFromCache(ctx, w, cacheEntry, true)
}