From 7fbdfffd90d218b241b16b04314cb4cda9f948b8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 18 Oct 2024 21:02:55 +0300 Subject: [PATCH] hicli/send: encrypt message asynchronously --- pkg/hicli/database/event.go | 15 ++++--- pkg/hicli/send.go | 83 +++++++++++++++++++++---------------- 2 files changed, 58 insertions(+), 40 deletions(-) diff --git a/pkg/hicli/database/event.go b/pkg/hicli/database/event.go index 1520dc4..3bcb0b0 100644 --- a/pkg/hicli/database/event.go +++ b/pkg/hicli/database/event.go @@ -59,11 +59,12 @@ const ( unsigned=excluded.unsigned RETURNING rowid ` - updateEventSendErrorQuery = `UPDATE event SET send_error = $2 WHERE rowid = $1` - updateEventIDQuery = `UPDATE event SET event_id = $2, send_error = NULL WHERE rowid=$1` - updateEventDecryptedQuery = `UPDATE event SET decrypted = $2, decrypted_type = $3, decryption_error = NULL, unread_type = $4, local_content = $5 WHERE rowid = $1` - updateEventLocalContentQuery = `UPDATE event SET local_content = $2 WHERE rowid = $1` - getEventReactionsQuery = getEventBaseQuery + ` + updateEventSendErrorQuery = `UPDATE event SET send_error = $2 WHERE rowid = $1` + updateEventIDQuery = `UPDATE event SET event_id = $2, send_error = NULL WHERE rowid=$1` + updateEventDecryptedQuery = `UPDATE event SET decrypted = $2, decrypted_type = $3, decryption_error = NULL, unread_type = $4, local_content = $5 WHERE rowid = $1` + updateEventLocalContentQuery = `UPDATE event SET local_content = $2 WHERE rowid = $1` + updateEventEncryptedContentQuery = `UPDATE event SET content = $2, megolm_session_id = $3 WHERE rowid = $1` + getEventReactionsQuery = getEventBaseQuery + ` WHERE room_id = ? AND type = 'm.reaction' AND relation_type = 'm.annotation' @@ -150,6 +151,10 @@ func (eq *EventQuery) UpdateLocalContent(ctx context.Context, evt *Event) error return eq.Exec(ctx, updateEventLocalContentQuery, evt.RowID, dbutil.JSONPtr(evt.LocalContent)) } +func (eq *EventQuery) UpdateEncryptedContent(ctx context.Context, evt *Event) error { + return eq.Exec(ctx, updateEventEncryptedContentQuery, evt.RowID, unsafeJSONString(evt.Content), evt.MegolmSessionID) +} + func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error { eventIDs := make([]id.EventID, 0) eventMap := make(map[id.EventID]*Event) diff --git a/pkg/hicli/send.go b/pkg/hicli/send.go index 9529259..8f0ba54 100644 --- a/pkg/hicli/send.go +++ b/pkg/hicli/send.go @@ -91,48 +91,36 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ } else if roomMeta == nil { return nil, fmt.Errorf("unknown room") } - var decryptedType event.Type - var decryptedContent json.RawMessage - var megolmSessionID id.SessionID - if roomMeta.EncryptionEvent != nil && evtType != event.EventReaction { - decryptedType = evtType - decryptedContent, err = json.Marshal(content) - if err != nil { - return nil, fmt.Errorf("failed to marshal event content: %w", err) - } - encryptedContent, err := h.Encrypt(ctx, roomMeta, evtType, content) - if err != nil { - return nil, fmt.Errorf("failed to encrypt event: %w", err) - } - megolmSessionID = encryptedContent.SessionID - content = encryptedContent - evtType = event.EventEncrypted - } - mainContent, err := json.Marshal(content) - if err != nil { - return nil, fmt.Errorf("failed to marshal event content: %w", err) - } txnID := "hicli-" + h.Client.TxnID() - relatesTo, relationType := database.GetRelatesToFromBytes(mainContent) dbEvt := &database.Event{ RoomID: roomID, ID: id.EventID(fmt.Sprintf("~%s", txnID)), Sender: h.Account.UserID, - Type: evtType.Type, Timestamp: jsontime.UnixMilliNow(), - Content: mainContent, - Decrypted: decryptedContent, - DecryptedType: decryptedType.Type, Unsigned: []byte("{}"), TransactionID: txnID, - RelatesTo: relatesTo, - RelationType: relationType, - MegolmSessionID: megolmSessionID, DecryptionError: "", SendError: "not sent", Reactions: map[string]int{}, LastEditRowID: ptr.Ptr(database.EventRowID(0)), } + if roomMeta.EncryptionEvent != nil && evtType != event.EventReaction { + dbEvt.Type = event.EventEncrypted.Type + dbEvt.DecryptedType = evtType.Type + dbEvt.Decrypted, err = json.Marshal(content) + if err != nil { + return nil, fmt.Errorf("failed to marshal event content: %w", err) + } + dbEvt.Content = json.RawMessage("{}") + dbEvt.RelatesTo, dbEvt.RelationType = database.GetRelatesToFromBytes(dbEvt.Decrypted) + } else { + dbEvt.Type = evtType.Type + dbEvt.Content, err = json.Marshal(content) + if err != nil { + return nil, fmt.Errorf("failed to marshal event content: %w", err) + } + dbEvt.RelatesTo, dbEvt.RelationType = database.GetRelatesToFromBytes(dbEvt.Content) + } dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, dbEvt.AsRawMautrix()) _, err = h.DB.Event.Insert(ctx, dbEvt) if err != nil { @@ -148,13 +136,43 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ go func() { var err error defer func() { + if dbEvt.SendError != "" { + err2 := h.DB.Event.UpdateSendError(ctx, dbEvt.RowID, dbEvt.SendError) + if err2 != nil { + zerolog.Ctx(ctx).Err(err2).AnErr("send_error", err). + Msg("Failed to update send error in database after sending failed") + } + } h.EventHandler(&SendComplete{ Event: dbEvt, Error: err, }) }() + if dbEvt.Decrypted != nil { + var encryptedContent *event.EncryptedEventContent + encryptedContent, err = h.Encrypt(ctx, roomMeta, evtType, dbEvt.Decrypted) + if err != nil { + dbEvt.SendError = fmt.Sprintf("failed to encrypt: %v", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to encrypt event") + return + } + evtType = event.EventEncrypted + dbEvt.MegolmSessionID = encryptedContent.SessionID + dbEvt.Content, err = json.Marshal(encryptedContent) + if err != nil { + dbEvt.SendError = fmt.Sprintf("failed to marshal encrypted content: %v", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to marshal encrypted content") + return + } + err = h.DB.Event.UpdateEncryptedContent(ctx, dbEvt) + if err != nil { + dbEvt.SendError = fmt.Sprintf("failed to save event after encryption: %v", err) + zerolog.Ctx(ctx).Err(err).Msg("Failed to save event after encryption") + return + } + } var resp *mautrix.RespSendEvent - resp, err = h.Client.SendMessageEvent(ctx, roomID, evtType, content, mautrix.ReqSendEvent{ + resp, err = h.Client.SendMessageEvent(ctx, roomID, evtType, dbEvt.Content, mautrix.ReqSendEvent{ Timestamp: dbEvt.Timestamp.UnixMilli(), TransactionID: txnID, DontEncrypt: true, @@ -162,11 +180,6 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ if err != nil { dbEvt.SendError = err.Error() err = fmt.Errorf("failed to send event: %w", err) - err2 := h.DB.Event.UpdateSendError(ctx, dbEvt.RowID, dbEvt.SendError) - if err2 != nil { - zerolog.Ctx(ctx).Err(err2).AnErr("send_error", err). - Msg("Failed to update send error in database after sending failed") - } return } dbEvt.ID = resp.EventID