From 4060383efa5e25383d42fbde2547e7dab8bafa8b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 7 Feb 2025 19:32:29 +0200 Subject: [PATCH] hicli/sync: invalidate outbound sessions on member change --- pkg/hicli/sync.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/pkg/hicli/sync.go b/pkg/hicli/sync.go index ad515de..673dda8 100644 --- a/pkg/hicli/sync.go +++ b/pkg/hicli/sync.go @@ -93,6 +93,42 @@ func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.Res return nil } +func (h *HiClient) maybeDiscardOutboundSession(ctx context.Context, newMembership event.Membership, evt *event.Event) bool { + var prevMembership event.Membership = "unknown" + if evt.Unsigned.PrevContent != nil { + prevMembership = event.Membership(gjson.GetBytes(evt.Unsigned.PrevContent.VeryRaw, "membership").Str) + } + if prevMembership == "unknown" || prevMembership == "" { + cs, err := h.DB.CurrentState.Get(ctx, evt.RoomID, event.StateMember, h.Account.UserID.String()) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("room_id", evt.RoomID). + Str("user_id", evt.GetStateKey()). + Msg("Failed to get previous membership") + return false + } + prevMembership = event.Membership(gjson.GetBytes(cs.Content, "membership").Str) + } + if prevMembership == newMembership || + (prevMembership == event.MembershipInvite && newMembership == event.MembershipJoin) || + (prevMembership == event.MembershipBan && newMembership == event.MembershipLeave) || + (prevMembership == event.MembershipLeave && newMembership == event.MembershipBan) { + return false + } + zerolog.Ctx(ctx).Debug(). + Stringer("room_id", evt.RoomID). + Str("user_id", evt.GetStateKey()). + Str("prev_membership", string(prevMembership)). + Str("new_membership", string(newMembership)). + Msg("Got membership state change, invalidating group session in room") + err := h.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID) + if err != nil { + zerolog.Ctx(ctx).Warn().Stringer("room_id", evt.RoomID).Msg("Failed to invalidate outbound group session") + return false + } + return true +} + func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount) go h.asyncPostProcessSyncResponse(ctx, resp, since) @@ -297,6 +333,10 @@ func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, ro if err != nil { return fmt.Errorf("failed to delete invited room: %w", err) } + err = h.CryptoStore.RemoveOutboundGroupSession(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to remove outbound group session: %w", err) + } payload := ctx.Value(syncContextKey).(*syncContext).evt payload.LeftRooms = append(payload.LeftRooms, roomID) return nil @@ -715,6 +755,7 @@ func (h *HiClient) processStateAndTimeline( } return nil } + megolmSessionDiscarded := false processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) { evt.RoomID = room.ID dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, evt.Unsigned.TransactionID != "") @@ -747,6 +788,9 @@ func (h *HiClient) processStateAndTimeline( if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) { heroesChanged = true } + if !megolmSessionDiscarded && room.EncryptionEvent != nil { + megolmSessionDiscarded = h.maybeDiscardOutboundSession(ctx, membership, evt) + } } else if evt.Type == event.StateElementFunctionalMembers { heroesChanged = true }