// Source: gomuks/pkg/hicli/decryptionqueue.go (215 lines, 6.7 KiB, Go)

// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"fmt"
"sync"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
)
// fetchFromKeyBackup looks up a single megolm session in the server-side key
// backup (using h.KeyBackupVersion), decrypts it with h.KeyBackupKey and
// imports it into the crypto machine.
//
// It returns (nil, nil) when the backup lookup yields no data, and a non-nil
// error if the lookup, the decryption or the import fails.
func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
	backupData, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID)
	switch {
	case err != nil:
		return nil, err
	case backupData == nil:
		// Nothing stored for this session; not an error.
		return nil, nil
	}

	sessionData, err := backupData.SessionData.Decrypt(h.KeyBackupKey)
	if err != nil {
		return nil, err
	}
	return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, sessionData)
}
// handleReceivedMegolmSession reacts to a newly available megolm session.
//
// It removes the matching entry from the session request table, flags the
// sync context (when one is attached to ctx) so the request queue gets woken
// up, and then retries decryption of every event that previously failed to
// decrypt with this session. Events that now decrypt are saved in a single
// transaction, the room preview is updated where applicable, and an
// EventsDecrypted event is dispatched to h.EventHandler.
//
// Failures are logged rather than returned.
func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) {
	log := zerolog.Ctx(ctx)
	if removeErr := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex); removeErr != nil {
		log.Warn().Err(removeErr).Msg("Failed to remove session request after receiving megolm session")
	}

	// When receiving megolm sessions in sync, wake up the request queue to ensure they get uploaded to key backup
	if syncCtx, inSync := ctx.Value(syncContextKey).(*syncContext); inSync {
		syncCtx.shouldWakeupRequestQueue = true
	}

	events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID)
	if err != nil {
		log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption")
		return
	}
	if len(events) == 0 {
		log.Trace().Msg("No events to retry decryption for")
		return
	}

	// Filter in place: decrypted reuses the backing array of events and only
	// ever writes at or before the index currently being read.
	decrypted := events[:0]
	for _, evt := range events {
		if evt.Decrypted != nil {
			continue
		}
		var mautrixEvt *event.Event
		mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
		if err != nil {
			log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
			continue
		}
		decrypted = append(decrypted, evt)
		h.postDecryptProcess(ctx, nil, evt, mautrixEvt)
	}
	if len(decrypted) == 0 {
		return
	}

	// Row ID of the last event (in iteration order) that replaced the room preview.
	var newPreview database.EventRowID
	err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
		for _, evt := range decrypted {
			if saveErr := h.DB.Event.UpdateDecrypted(ctx, evt); saveErr != nil {
				return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, saveErr)
			}
			if !evt.CanUseForPreview() {
				continue
			}
			previewChanged, previewErr := h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID)
			if previewErr != nil {
				return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, previewErr)
			}
			if previewChanged {
				newPreview = evt.RowID
			}
		}
		return nil
	})
	if err != nil {
		log.Err(err).Msg("Failed to save decrypted events")
		return
	}
	h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewEventRowID: newPreview, RoomID: roomID})
}
// WakeupRequestQueue signals the request queue loop to run another pass.
//
// The send on h.requestQueueWakeup is non-blocking: if it cannot proceed
// immediately, the signal is dropped and the call returns right away.
func (h *HiClient) WakeupRequestQueue() {
	select {
	case h.requestQueueWakeup <- struct{}{}:
		// Signal delivered.
	default:
		// Send would block; skip instead of waiting.
	}
}
// RunRequestQueue runs the key request queue loop until ctx is cancelled.
//
// Each pass fetches device keys for outdated tracked users, uploads keys to
// the key backup and then processes queued session requests. If any requests
// were processed, another pass starts immediately; otherwise the loop blocks
// until WakeupRequestQueue is called or ctx is done.
//
// Errors from the individual steps are logged and do not stop the loop.
func (h *HiClient) RunRequestQueue(ctx context.Context) {
	log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger()
	ctx = log.WithContext(ctx)
	log.Info().Msg("Starting key request queue")
	defer func() {
		log.Info().Msg("Stopping key request queue")
	}()
	for {
		err := h.FetchKeysForOutdatedUsers(ctx)
		if err != nil {
			log.Err(err).Msg("Failed to fetch outdated device lists for tracked users")
		}
		h.uploadKeysToBackup(ctx)
		madeRequests, err := h.RequestQueuedSessions(ctx)
		if err != nil {
			log.Err(err).Msg("Failed to handle session request queue")
		} else if madeRequests {
			// The immediate-retry path skips the select below, so check for
			// cancellation here; otherwise a cancelled context would trigger
			// another full pass before the loop can exit.
			if ctx.Err() != nil {
				return
			}
			continue
		}
		select {
		case <-ctx.Done():
			return
		case <-h.requestQueueWakeup:
		}
	}
}
// requestQueuedSession processes one queued session request and always calls
// doneFunc before returning.
//
// If the key backup has not been checked for this request yet, it tries to
// fetch the session from the backup:
//   - on error, or if no session (or only one whose first known index is
//     above req.MinIndex) was found, the request is marked as backup-checked
//     and saved;
//   - otherwise the request is removed from the queue.
//
// If the backup was already checked, a room key request is sent to all
// devices of the own user and of the event sender, and the request is marked
// as sent. All failures are logged rather than returned.
func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) {
	defer doneFunc()
	log := zerolog.Ctx(ctx)

	if req.BackupChecked {
		// Backup already consulted: fall back to asking other devices.
		recipients := map[id.UserID][]id.DeviceID{
			h.Account.UserID: {"*"},
			req.Sender:       {"*"},
		}
		if sendErr := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", recipients); sendErr != nil {
			log.Err(sendErr).
				Stringer("session_id", req.SessionID).
				Msg("Failed to send key request")
			return
		}
		log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request")
		req.RequestSent = true
		if putErr := h.DB.SessionRequest.Put(ctx, req); putErr != nil {
			log.Err(putErr).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request")
		}
		return
	}

	sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID)
	switch {
	case err != nil:
		log.Err(err).
			Stringer("session_id", req.SessionID).
			Msg("Failed to fetch session from key backup")
		// TODO should this have retries instead of just storing it's checked?
		req.BackupChecked = true
		if putErr := h.DB.SessionRequest.Put(ctx, req); putErr != nil {
			log.Err(putErr).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup")
		}
	case sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex:
		// Backup had nothing usable for the index we need.
		req.BackupChecked = true
		if putErr := h.DB.SessionRequest.Put(ctx, req); putErr != nil {
			log.Err(putErr).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup")
		}
	default:
		log.Debug().Stringer("session_id", req.SessionID).
			Msg("Found session with sufficiently low first known index, removing from queue")
		if removeErr := h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex()); removeErr != nil {
			log.Err(removeErr).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue")
		}
	}
}
// MaxParallelRequests is the batch size passed to SessionRequest.Next by
// RequestQueuedSessions; each returned request is handled in its own goroutine.
const MaxParallelRequests = 5

// RequestQueuedSessions fetches the next batch of queued session requests and
// processes them concurrently, waiting for all of them to finish.
//
// It returns true if at least one request was processed, false if the queue
// had nothing to process. An error is returned only when the batch could not
// be loaded from the database; per-request failures are logged by
// requestQueuedSession.
func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) {
	pending, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests)
	if err != nil {
		return false, fmt.Errorf("failed to get next events to decrypt: %w", err)
	}
	if len(pending) == 0 {
		return false, nil
	}

	var wg sync.WaitGroup
	for _, req := range pending {
		wg.Add(1)
		go h.requestQueuedSession(ctx, req, wg.Done)
	}
	wg.Wait()
	return true, nil
}
// FetchKeysForOutdatedUsers reads the list of outdated tracked users from the
// crypto store and fetches their keys.
//
// It returns nil without fetching when there are no outdated users, and
// otherwise returns any error from the store lookup or the key fetch.
func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error {
	users, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx)
	if err != nil {
		return err
	}
	if len(users) == 0 {
		return nil
	}
	// TODO backoff for users that fail to be fetched?
	_, err = h.Crypto.FetchKeys(ctx, users, false)
	return err
}