1
0
Fork 0
forked from Mirrors/gomuks

hicli/sync: invalidate outbound sessions on member change

This commit is contained in:
Tulir Asokan 2025-02-07 19:32:29 +02:00
parent 14c9291c8d
commit 4060383efa

View file

@ -93,6 +93,42 @@ func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.Res
return nil return nil
} }
func (h *HiClient) maybeDiscardOutboundSession(ctx context.Context, newMembership event.Membership, evt *event.Event) bool {
var prevMembership event.Membership = "unknown"
if evt.Unsigned.PrevContent != nil {
prevMembership = event.Membership(gjson.GetBytes(evt.Unsigned.PrevContent.VeryRaw, "membership").Str)
}
if prevMembership == "unknown" || prevMembership == "" {
cs, err := h.DB.CurrentState.Get(ctx, evt.RoomID, event.StateMember, h.Account.UserID.String())
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Stringer("room_id", evt.RoomID).
Str("user_id", evt.GetStateKey()).
Msg("Failed to get previous membership")
return false
}
prevMembership = event.Membership(gjson.GetBytes(cs.Content, "membership").Str)
}
if prevMembership == newMembership ||
(prevMembership == event.MembershipInvite && newMembership == event.MembershipJoin) ||
(prevMembership == event.MembershipBan && newMembership == event.MembershipLeave) ||
(prevMembership == event.MembershipLeave && newMembership == event.MembershipBan) {
return false
}
zerolog.Ctx(ctx).Debug().
Stringer("room_id", evt.RoomID).
Str("user_id", evt.GetStateKey()).
Str("prev_membership", string(prevMembership)).
Str("new_membership", string(newMembership)).
Msg("Got membership state change, invalidating group session in room")
err := h.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID)
if err != nil {
zerolog.Ctx(ctx).Warn().Stringer("room_id", evt.RoomID).Msg("Failed to invalidate outbound group session")
return false
}
return true
}
func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount) h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
go h.asyncPostProcessSyncResponse(ctx, resp, since) go h.asyncPostProcessSyncResponse(ctx, resp, since)
@ -297,6 +333,10 @@ func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, ro
if err != nil { if err != nil {
return fmt.Errorf("failed to delete invited room: %w", err) return fmt.Errorf("failed to delete invited room: %w", err)
} }
err = h.CryptoStore.RemoveOutboundGroupSession(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to remove outbound group session: %w", err)
}
payload := ctx.Value(syncContextKey).(*syncContext).evt payload := ctx.Value(syncContextKey).(*syncContext).evt
payload.LeftRooms = append(payload.LeftRooms, roomID) payload.LeftRooms = append(payload.LeftRooms, roomID)
return nil return nil
@ -715,6 +755,7 @@ func (h *HiClient) processStateAndTimeline(
} }
return nil return nil
} }
megolmSessionDiscarded := false
processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) { processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) {
evt.RoomID = room.ID evt.RoomID = room.ID
dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, evt.Unsigned.TransactionID != "") dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, evt.Unsigned.TransactionID != "")
@ -747,6 +788,9 @@ func (h *HiClient) processStateAndTimeline(
if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) { if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) {
heroesChanged = true heroesChanged = true
} }
if !megolmSessionDiscarded && room.EncryptionEvent != nil {
megolmSessionDiscarded = h.maybeDiscardOutboundSession(ctx, membership, evt)
}
} else if evt.Type == event.StateElementFunctionalMembers { } else if evt.Type == event.StateElementFunctionalMembers {
heroesChanged = true heroesChanged = true
} }