mirror of
https://github.com/tulir/gomuks.git
synced 2025-04-19 18:13:41 -05:00
215 lines
6.6 KiB
Go
215 lines
6.6 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package hicli
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/rs/zerolog"
|
|
"maunium.net/go/mautrix/crypto"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/id"
|
|
|
|
"go.mau.fi/gomuks/pkg/hicli/database"
|
|
)
|
|
|
|
func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
|
|
data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if data == nil {
|
|
return nil, nil
|
|
}
|
|
decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted)
|
|
}
|
|
|
|
func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) {
|
|
log := zerolog.Ctx(ctx)
|
|
err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex)
|
|
if err != nil {
|
|
log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session")
|
|
}
|
|
// When receiving megolm sessions in sync, wake up the request queue to ensure they get uploaded to key backup
|
|
syncCtx, ok := ctx.Value(syncContextKey).(*syncContext)
|
|
if ok {
|
|
syncCtx.shouldWakeupRequestQueue = true
|
|
}
|
|
events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption")
|
|
return
|
|
} else if len(events) == 0 {
|
|
log.Trace().Msg("No events to retry decryption for")
|
|
return
|
|
}
|
|
decrypted := events[:0]
|
|
for _, evt := range events {
|
|
if evt.Decrypted != nil {
|
|
continue
|
|
}
|
|
|
|
var mautrixEvt *event.Event
|
|
mautrixEvt, err = h.decryptEventInto(ctx, evt.AsRawMautrix(), evt)
|
|
if err != nil {
|
|
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
|
|
} else {
|
|
decrypted = append(decrypted, evt)
|
|
h.postDecryptProcess(ctx, nil, evt, mautrixEvt)
|
|
}
|
|
}
|
|
if len(decrypted) > 0 {
|
|
var newPreview database.EventRowID
|
|
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
|
for _, evt := range decrypted {
|
|
err = h.DB.Event.UpdateDecrypted(ctx, evt)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err)
|
|
}
|
|
if evt.CanUseForPreview() {
|
|
var previewChanged bool
|
|
previewChanged, err = h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, err)
|
|
} else if previewChanged {
|
|
newPreview = evt.RowID
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to save decrypted events")
|
|
} else {
|
|
h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewEventRowID: newPreview, RoomID: roomID})
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) WakeupRequestQueue() {
|
|
select {
|
|
case h.requestQueueWakeup <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) RunRequestQueue(ctx context.Context) {
|
|
log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger()
|
|
ctx = log.WithContext(ctx)
|
|
log.Info().Msg("Starting key request queue")
|
|
defer func() {
|
|
log.Info().Msg("Stopping key request queue")
|
|
}()
|
|
for {
|
|
err := h.FetchKeysForOutdatedUsers(ctx)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to fetch outdated device lists for tracked users")
|
|
}
|
|
h.uploadKeysToBackup(ctx)
|
|
madeRequests, err := h.RequestQueuedSessions(ctx)
|
|
if err != nil {
|
|
log.Err(err).Msg("Failed to handle session request queue")
|
|
} else if madeRequests {
|
|
continue
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-h.requestQueueWakeup:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) {
|
|
defer doneFunc()
|
|
log := zerolog.Ctx(ctx)
|
|
if !req.BackupChecked {
|
|
sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID)
|
|
if err != nil {
|
|
log.Err(err).
|
|
Stringer("session_id", req.SessionID).
|
|
Msg("Failed to fetch session from key backup")
|
|
|
|
// TODO should this have retries instead of just storing it's checked?
|
|
req.BackupChecked = true
|
|
err = h.DB.SessionRequest.Put(ctx, req)
|
|
if err != nil {
|
|
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup")
|
|
}
|
|
} else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex {
|
|
req.BackupChecked = true
|
|
err = h.DB.SessionRequest.Put(ctx, req)
|
|
if err != nil {
|
|
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup")
|
|
}
|
|
} else {
|
|
log.Debug().Stringer("session_id", req.SessionID).
|
|
Msg("Found session with sufficiently low first known index, removing from queue")
|
|
err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex())
|
|
if err != nil {
|
|
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue")
|
|
}
|
|
}
|
|
} else {
|
|
err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{
|
|
h.Account.UserID: {"*"},
|
|
req.Sender: {"*"},
|
|
})
|
|
//var err error
|
|
if err != nil {
|
|
log.Err(err).
|
|
Stringer("session_id", req.SessionID).
|
|
Msg("Failed to send key request")
|
|
} else {
|
|
log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request")
|
|
req.RequestSent = true
|
|
err = h.DB.SessionRequest.Put(ctx, req)
|
|
if err != nil {
|
|
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
const MaxParallelRequests = 5
|
|
|
|
func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) {
|
|
sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to get next events to decrypt: %w", err)
|
|
} else if len(sessions) == 0 {
|
|
return false, nil
|
|
}
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(sessions))
|
|
for _, req := range sessions {
|
|
go h.requestQueuedSession(ctx, req, wg.Done)
|
|
}
|
|
wg.Wait()
|
|
|
|
return true, err
|
|
}
|
|
|
|
func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error {
|
|
outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx)
|
|
if err != nil {
|
|
return err
|
|
} else if len(outdatedUsers) == 0 {
|
|
return nil
|
|
}
|
|
_, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// TODO backoff for users that fail to be fetched?
|
|
return nil
|
|
}
|