package main

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"mime"
	"net/http"
	"os"
	"path/filepath"
	"strconv"

	"github.com/rs/zerolog"
	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/hicli/database"
	"maunium.net/go/mautrix/id"
)

// ErrBadGateway is the error response written by DownloadMedia when fetching
// the media through the Matrix client fails.
var ErrBadGateway = mautrix.RespError{
	ErrCode:    "FI.MAU.GOMUKS.BAD_GATEWAY",
	StatusCode: http.StatusBadGateway,
}

// downloadMediaFromCache tries to serve the media described by entry from the
// on-disk cache.
//
// It returns true when a response (the media itself or an error) has been
// written to w, and false when nothing was written and the caller should fall
// back to another source. When force is true the function always writes a
// response: a missing entry or missing file becomes an error response instead
// of a fallback.
func (gmx *Gomuks) downloadMediaFromCache(ctx context.Context, w http.ResponseWriter, entry *database.CachedMedia, force bool) bool {
	if entry == nil || entry.Hash == nil {
		if force {
			mautrix.MNotFound.WithMessage("Media not found in cache").Write(w)
			return true
		}
		return false
	}
	log := zerolog.Ctx(ctx)
	cacheFile, err := os.Open(gmx.cacheEntryToPath(entry))
	if err != nil {
		// A missing file is only a soft failure when the caller can still
		// fetch the media from elsewhere.
		if errors.Is(err, os.ErrNotExist) && !force {
			return false
		}
		log.Err(err).Msg("Failed to open cache file")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to open cache file: %v", err)).Write(w)
		return true
	}
	defer func() {
		_ = cacheFile.Close()
	}()
	cacheEntryToHeaders(w, entry)
	w.WriteHeader(http.StatusOK)
	_, err = io.Copy(w, cacheFile)
	if err != nil {
		// The status line has already been sent, so the failure can only be logged.
		log.Err(err).Msg("Failed to copy cache file to response")
	}
	return true
}

// cacheEntryToPath returns the on-disk location of a cached file:
// <CacheDir>/media/<hex[0:2]>/<hex[2:4]>/<hex[4:]>, where hex is the
// hex-encoded entry hash. entry.Hash is dereferenced and must not be nil.
func (gmx *Gomuks) cacheEntryToPath(entry *database.CachedMedia) string {
	hashPath := hex.EncodeToString(entry.Hash[:])
	return filepath.Join(gmx.CacheDir, "media", hashPath[0:2], hashPath[2:4], hashPath[4:])
}

// cacheEntryToHeaders sets the Content-Type, Content-Length,
// Content-Disposition and Content-Security-Policy response headers from the
// metadata stored in entry.
func cacheEntryToHeaders(w http.ResponseWriter, entry *database.CachedMedia) {
	w.Header().Set("Content-Type", entry.MimeType)
	w.Header().Set("Content-Length", strconv.FormatInt(entry.Size, 10))
	w.Header().Set("Content-Disposition", mime.FormatMediaType(entry.ContentDisposition(), map[string]string{"filename": entry.FileName}))
	w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none';")
}

// DownloadMedia is the HTTP handler that serves a piece of Matrix media.
//
// The mxc URI is built from the "server" and "media_id" path values. If the
// "encrypted" query parameter parses as true, the request is rejected unless
// decryption keys for the media are present in the cache database. Media that
// is already on disk is served directly; otherwise it is downloaded (and
// decrypted when keys are cached), hashed with SHA-256, stored in the cache
// and then served from there.
func (gmx *Gomuks) DownloadMedia(w http.ResponseWriter, r *http.Request) {
	mxc := id.ContentURI{
		Homeserver: r.PathValue("server"),
		FileID:     r.PathValue("media_id"),
	}
	if !mxc.IsValid() {
		mautrix.MInvalidParam.WithMessage("Invalid mxc URI").Write(w)
		return
	}
	query := r.URL.Query()
	// An absent or unparsable value is treated as "not encrypted".
	encrypted, _ := strconv.ParseBool(query.Get("encrypted"))

	logVal := zerolog.Ctx(r.Context()).With().
		Stringer("mxc_uri", mxc).
		Bool("encrypted", encrypted).
		Logger()
	log := &logVal
	ctx := log.WithContext(r.Context())

	cacheEntry, err := gmx.Client.DB.CachedMedia.Get(ctx, mxc)
	if err != nil {
		log.Err(err).Msg("Failed to get cached media entry")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to get cached media entry: %v", err)).Write(w)
		return
	} else if (cacheEntry == nil || cacheEntry.EncFile == nil) && encrypted {
		mautrix.MNotFound.WithMessage("Media encryption keys not found in cache").Write(w)
		return
	}

	if gmx.downloadMediaFromCache(ctx, w, cacheEntry, false) {
		return
	}

	// Create the temporary file inside the media cache directory rather than
	// the system temp directory: the final os.Rename into the cache fails when
	// source and destination are on different filesystems.
	mediaDir := filepath.Join(gmx.CacheDir, "media")
	err = os.MkdirAll(mediaDir, 0700)
	if err != nil {
		log.Err(err).Msg("Failed to create cache directory")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to create cache directory: %v", err)).Write(w)
		return
	}
	tempFile, err := os.CreateTemp(mediaDir, "gomuks-download-*")
	if err != nil {
		log.Err(err).Msg("Failed to create temporary file")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to create temp file: %v", err)).Write(w)
		return
	}
	defer func() {
		// Best-effort cleanup: on the success path the file is already closed
		// and renamed away, so both calls are expected to fail harmlessly.
		_ = tempFile.Close()
		_ = os.Remove(tempFile.Name())
	}()

	resp, err := gmx.Client.Client.Download(ctx, mxc)
	if err != nil {
		log.Err(err).Msg("Failed to download media")
		ErrBadGateway.WithMessage(err.Error()).Write(w)
		return
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	if cacheEntry == nil {
		cacheEntry = &database.CachedMedia{
			MXC:      mxc,
			MimeType: resp.Header.Get("Content-Type"),
			Size:     resp.ContentLength,
		}
	}

	reader := resp.Body
	if cacheEntry.EncFile != nil {
		err = cacheEntry.EncFile.PrepareForDecryption()
		if err != nil {
			log.Err(err).Msg("Failed to prepare media for decryption")
			mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to prepare media for decryption: %v", err)).Write(w)
			return
		}
		reader = cacheEntry.EncFile.DecryptStream(reader)
	}

	// Hash the bytes while they are written to disk; the digest becomes the
	// cache key and the byte count replaces any size taken from the headers.
	fileHasher := sha256.New()
	hashReader := io.TeeReader(reader, fileHasher)
	cacheEntry.Size, err = io.Copy(tempFile, hashReader)
	if err != nil {
		log.Err(err).Msg("Failed to copy media to temporary file")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to copy media to temp file: %v", err)).Write(w)
		return
	}
	err = reader.Close()
	if err != nil {
		log.Err(err).Msg("Failed to close media reader")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to finish reading media: %v", err)).Write(w)
		return
	}
	// A failed close can mean the data did not reach the disk, so the file
	// must not be recorded as valid cached media in that case.
	err = tempFile.Close()
	if err != nil {
		log.Err(err).Msg("Failed to close temporary file")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to close temp file: %v", err)).Write(w)
		return
	}

	if cacheEntry.FileName == "" {
		// On a parse error params is nil and the lookup yields "".
		_, params, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition"))
		cacheEntry.FileName = params["filename"]
	}
	if cacheEntry.MimeType == "" {
		cacheEntry.MimeType = resp.Header.Get("Content-Type")
	}
	cacheEntry.Hash = (*[32]byte)(fileHasher.Sum(nil))

	err = gmx.Client.DB.CachedMedia.Put(ctx, cacheEntry)
	if err != nil {
		log.Err(err).Msg("Failed to save cache entry")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to save cache entry: %v", err)).Write(w)
		return
	}

	cachePath := gmx.cacheEntryToPath(cacheEntry)
	err = os.MkdirAll(filepath.Dir(cachePath), 0700)
	if err != nil {
		log.Err(err).Msg("Failed to create cache directory")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to create cache directory: %v", err)).Write(w)
		return
	}
	err = os.Rename(tempFile.Name(), cachePath)
	if err != nil {
		log.Err(err).Msg("Failed to rename temporary file")
		mautrix.MUnknown.WithMessage(fmt.Sprintf("Failed to rename temp file: %v", err)).Write(w)
		return
	}

	// force=true: the file was just stored, so any failure here is reported
	// to the client instead of triggering another fallback.
	gmx.downloadMediaFromCache(ctx, w, cacheEntry, true)
}