gomuks/pkg/hicli/paginate.go
Tulir Asokan ef2e95a294
Some checks failed
Go / Lint Go (old) (push) Has been cancelled
Go / Lint Go (latest) (push) Has been cancelled
JS / Lint JS (push) Has been cancelled
hicli/paginate: fill prev content when resyncing state
2025-04-13 23:08:29 +03:00

375 lines
13 KiB
Go

// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"errors"
"fmt"
"slices"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
)
var (
	// ErrPaginationAlreadyInProgress is returned by PaginateServer when another
	// server pagination for the same room has not finished yet.
	ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress")
)
// GetEvent returns a single event, preferring the locally stored copy.
//
// If the event is already in the database, it is passed through
// ReprocessExistingEvent and returned as-is. Otherwise it is fetched from the
// homeserver and run through processEvent before being returned.
func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) {
	stored, err := h.DB.Event.GetByID(ctx, eventID)
	if err != nil {
		return nil, fmt.Errorf("failed to get event from database: %w", err)
	}
	if stored != nil {
		h.ReprocessExistingEvent(ctx, stored)
		return stored, nil
	}
	remote, err := h.Client.GetEvent(ctx, roomID, eventID)
	if err != nil {
		return nil, fmt.Errorf("failed to get event from server: %w", err)
	}
	return h.processEvent(ctx, remote, nil, nil, false)
}
// GetUnredactedEvent returns an event with its original (pre-redaction) content.
//
// If the database already holds the event with non-empty content, that copy is
// reprocessed and returned. Otherwise the unredacted content is requested from
// the server. The regular (redacted) event is then fetched too, so that its
// unsigned.redacted_because can be copied onto the result before it is run
// through processEvent.
func (h *HiClient) GetUnredactedEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) {
	evt, err := h.DB.Event.GetByID(ctx, eventID)
	if err != nil {
		return nil, fmt.Errorf("failed to get event from database: %w", err)
	}
	// TODO this check doesn't handle events which keep some fields on redaction
	if evt != nil && len(evt.Content) > 2 {
		h.ReprocessExistingEvent(ctx, evt)
		return evt, nil
	}
	serverEvt, err := h.Client.GetUnredactedEventContent(ctx, roomID, eventID)
	if err != nil {
		return nil, fmt.Errorf("failed to get event from server: %w", err)
	}
	// Validate the content before making a second request: there's no point in
	// fetching the redacted copy if the server gave us nothing useful.
	// "{}" is two bytes; presumably a response without a content field leaves
	// VeryRaw empty, which is also treated as "no content".
	// TODO this check will have false positives on actually empty events
	if len(serverEvt.Content.VeryRaw) <= 2 {
		return nil, errors.New("server didn't return content")
	}
	redactedServerEvt, err := h.Client.GetEvent(ctx, roomID, eventID)
	if err != nil {
		return nil, fmt.Errorf("failed to get redacted event from server: %w", err)
	}
	serverEvt.Unsigned.RedactedBecause = redactedServerEvt.Unsigned.RedactedBecause
	return h.processEvent(ctx, serverEvt, nil, nil, false)
}
// processGetRoomState loads room state from the homeserver and stores it locally.
//
// If refetch is true, the full current state is fetched. Otherwise, if
// fetchMembers is true, only the member list is fetched. If neither flag is
// set, nothing is done. The fetched events are saved along with current state
// entries and media references. Non-member state events are passed through
// processImportantEvent to update room metadata, and the room is marked as
// having its member list. If the room metadata changed and dispatchEvt is
// true, a SyncComplete event carrying the updated room is dispatched.
func (h *HiClient) processGetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch, dispatchEvt bool) error {
	var evts []*event.Event
	if refetch {
		resp, err := h.Client.StateAsArray(ctx, roomID)
		if err != nil {
			return fmt.Errorf("failed to refetch state: %w", err)
		}
		evts = resp
	} else if fetchMembers {
		resp, err := h.Client.Members(ctx, roomID)
		if err != nil {
			return fmt.Errorf("failed to fetch members: %w", err)
		}
		evts = resp.Chunk
	}
	if evts == nil {
		return nil
	}
	// Every entry needs a state key because it is dereferenced below. Drop
	// malformed entries instead of panicking: this function may run in a
	// background goroutine, where a panic would crash the whole process.
	evts = slices.DeleteFunc(evts, func(evt *event.Event) bool {
		return evt == nil || evt.StateKey == nil
	})
	dbEvts := make([]*database.Event, len(evts))
	currentStateEntries := make([]*database.CurrentStateEntry, len(evts))
	// Indexed in parallel with evts; entries stay nil for events without media.
	mediaReferenceEntries := make([]*database.MediaReference, len(evts))
	mediaCacheEntries := make([]*database.PlainMedia, 0, len(evts))
	for i, evt := range evts {
		if err := h.fillPrevContent(ctx, evt); err != nil {
			return err
		}
		dbEvts[i] = database.MautrixToEvent(evt)
		currentStateEntries[i] = &database.CurrentStateEntry{
			EventType: evt.Type,
			StateKey:  *evt.StateKey,
		}
		var mediaURL string
		if evt.Type == event.StateMember {
			// A missing or non-string membership must not panic; it is stored as "".
			membership, _ := evt.Content.Raw["membership"].(string)
			currentStateEntries[i].Membership = event.Membership(membership)
			mediaURL, _ = evt.Content.Raw["avatar_url"].(string)
		} else if evt.Type == event.StateRoomAvatar {
			mediaURL, _ = evt.Content.Raw["url"].(string)
		}
		if mxc := id.ContentURIString(mediaURL).ParseOrIgnore(); mxc.IsValid() {
			mediaCacheEntries = append(mediaCacheEntries, (*database.PlainMedia)(&mxc))
			mediaReferenceEntries[i] = &database.MediaReference{
				MediaMXC: mxc,
			}
		}
	}
	return h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
		room, err := h.DB.Room.Get(ctx, roomID)
		if err != nil {
			return fmt.Errorf("failed to get room from database: %w", err)
		} else if room == nil {
			return fmt.Errorf("room not found")
		}
		updatedRoom := &database.Room{
			ID:            room.ID,
			HasMemberList: true,
		}
		err = h.DB.Event.MassUpsertState(ctx, dbEvts)
		if err != nil {
			return fmt.Errorf("failed to save events: %w", err)
		}
		sdc := &spaceDataCollector{}
		for i := range currentStateEntries {
			// Row IDs are only known after MassUpsertState has run.
			currentStateEntries[i].EventRowID = dbEvts[i].RowID
			if mediaReferenceEntries[i] != nil {
				mediaReferenceEntries[i].EventRowID = dbEvts[i].RowID
			}
			if evts[i].Type != event.StateMember {
				processImportantEvent(ctx, evts[i], room, updatedRoom, dbEvts[i].RowID, sdc)
			}
		}
		err = h.DB.Media.AddMany(ctx, mediaCacheEntries)
		if err != nil {
			return fmt.Errorf("failed to save media cache entries: %w", err)
		}
		mediaReferenceEntries = slices.DeleteFunc(mediaReferenceEntries, func(reference *database.MediaReference) bool {
			return reference == nil
		})
		err = h.DB.Media.AddManyReferences(ctx, mediaReferenceEntries)
		if err != nil {
			return fmt.Errorf("failed to save media reference entries: %w", err)
		}
		err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, currentStateEntries)
		if err != nil {
			return fmt.Errorf("failed to save current state entries: %w", err)
		}
		roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
		// TODO dispatch space edge changes if something changed? (fairly unlikely though)
		err = sdc.Apply(ctx, room, h.DB.SpaceEdge)
		if err != nil {
			return err
		}
		if roomChanged {
			err = h.DB.Room.Upsert(ctx, updatedRoom)
			if err != nil {
				return fmt.Errorf("failed to save room data: %w", err)
			}
			// NOTE(review): this dispatches from inside the transaction callback,
			// i.e. presumably before DoTxn commits — confirm that is acceptable.
			if dispatchEvt {
				h.EventHandler(&SyncComplete{
					Rooms: map[id.RoomID]*SyncRoom{
						roomID: {
							Meta: room,
						},
					},
				})
			}
		}
		return nil
	})
}
// GetRoomState returns the locally stored current state of a room.
//
// If fetchMembers or refetch is set, state is first loaded from the homeserver
// via processGetRoomState. When includeMembers is true, that load happens
// synchronously so the result reflects it. When includeMembers is false, the
// load runs in a background goroutine with dispatchEvt enabled, using a context
// detached from ctx's cancellation. This call then returns the current local
// state (without member events) right away.
func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, includeMembers, fetchMembers, refetch bool) ([]*database.Event, error) {
	shouldFetch := fetchMembers || refetch
	switch {
	case shouldFetch && includeMembers:
		if err := h.processGetRoomState(ctx, roomID, fetchMembers, refetch, false); err != nil {
			return nil, err
		}
	case shouldFetch:
		bgCtx := context.WithoutCancel(ctx)
		go func() {
			if err := h.processGetRoomState(bgCtx, roomID, fetchMembers, refetch, true); err != nil {
				zerolog.Ctx(bgCtx).Err(err).Msg("Failed to fetch room state in background")
			}
		}()
	}
	if includeMembers {
		return h.DB.CurrentState.GetAll(ctx, roomID)
	}
	return h.DB.CurrentState.GetAllExceptMembers(ctx, roomID)
}
// PaginationResponse is the result of a Paginate or PaginateServer call.
type PaginationResponse struct {
	// Events are the timeline events in this page.
	Events []*database.Event `json:"events"`
	// Receipts maps event IDs from Events to read receipts (filled by Paginate via GetReceipts).
	Receipts map[id.EventID][]*database.Receipt `json:"receipts"`
	// RelatedEvents holds locally stored reply targets that are not part of Events.
	RelatedEvents []*database.Event `json:"related_events"`
	// HasMore reports whether further (older) history may still be available.
	HasMore bool `json:"has_more"`
}
// Paginate returns a page of timeline events for a room.
//
// Up to limit events are read from the local timeline, bounded by
// maxTimelineID. Only if the local timeline has nothing more is the homeserver
// asked for older history via PaginateServer. The response also contains read
// receipts for the page's events. RelatedEvents lists any locally stored
// reply targets that aren't part of the page itself.
func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) (*PaginationResponse, error) {
	evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID)
	if err != nil {
		return nil, fmt.Errorf("failed to get timeline from database: %w", err)
	}
	var resp *PaginationResponse
	if len(evts) > 0 {
		for _, evt := range evts {
			h.ReprocessExistingEvent(ctx, evt)
		}
		// Local events were found; there may be more further back (locally or
		// on the server), which the next call will discover.
		resp = &PaginationResponse{Events: evts, HasMore: true}
	} else {
		resp, err = h.PaginateServer(ctx, roomID, limit)
		if err != nil {
			return nil, err
		}
	}
	// Non-nil so that it serializes as [] rather than null.
	resp.RelatedEvents = make([]*database.Event, 0)
	eventIDs := make([]id.EventID, len(resp.Events))
	included := make(map[id.EventID]struct{}, len(resp.Events))
	// Walk the page from the end, marking each event as included before any
	// event earlier in the slice that replies to it is reached. This way a
	// reply target already in the page is not added again as a related event.
	for i := len(resp.Events) - 1; i >= 0; i-- {
		evt := resp.Events[i]
		eventIDs[i] = evt.ID
		included[evt.ID] = struct{}{}
		replyTo := evt.GetReplyTo()
		if replyTo == "" {
			continue
		}
		if _, alreadyIncluded := included[replyTo]; alreadyIncluded {
			continue
		}
		dbEvt, err := h.DB.Event.GetByID(ctx, replyTo)
		if err != nil {
			return nil, fmt.Errorf("failed to get reply-to event: %w", err)
		} else if dbEvt == nil {
			continue
		}
		// Reprocess it the same way as the page's own database events above.
		h.ReprocessExistingEvent(ctx, dbEvt)
		resp.RelatedEvents = append(resp.RelatedEvents, dbEvt)
		included[replyTo] = struct{}{}
	}
	resp.Receipts, err = h.GetReceipts(ctx, roomID, eventIDs)
	if err != nil {
		return nil, fmt.Errorf("failed to get receipts: %w", err)
	}
	return resp, nil
}
// GetReceipts returns the read receipts for the given events, keyed by event ID.
//
// eventIDs is expected in reverse chronological order (newest first). Each
// user's receipt is kept only on the first event it appears on, i.e. the
// newest one, and dropped from later events. This matters, for example, when
// threads give a user several receipts. The current user's own receipts are
// always omitted. RoomID is cleared on every kept receipt to shrink the payload.
func (h *HiClient) GetReceipts(ctx context.Context, roomID id.RoomID, eventIDs []id.EventID) (map[id.EventID][]*database.Receipt, error) {
	receipts, err := h.DB.Receipt.GetManyRead(ctx, roomID, eventIDs)
	if err != nil {
		return nil, err
	}
	seenUsers := make(map[id.UserID]struct{})
	// Own receipts are never returned, so treat our own user as already seen.
	seenUsers[h.Account.UserID] = struct{}{}
	for _, evtID := range eventIDs {
		forEvent := receipts[evtID]
		// Filter in place, reusing the backing array of the original slice.
		kept := forEvent[:0]
		for _, receipt := range forEvent {
			if _, dup := seenUsers[receipt.UserID]; dup {
				continue
			}
			seenUsers[receipt.UserID] = struct{}{}
			receipt.RoomID = ""
			kept = append(kept, receipt)
		}
		if len(forEvent) > 0 && len(kept) < len(forEvent) {
			receipts[evtID] = kept
		}
	}
	return receipts, nil
}
// PaginateServer fetches older events for a room from the homeserver and
// prepends them to the local timeline.
//
// Pagination starts at the room's stored prev_batch token and is done
// backwards, requesting up to limit events. Only one server pagination per room
// may run at a time; a concurrent call returns ErrPaginationAlreadyInProgress.
// The cancel function is registered in h.paginationInterrupter while the call
// runs, presumably so other code can interrupt it. Events that are already in
// the timeline are skipped. On error, the returned response is nil.
func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) (*PaginationResponse, error) {
	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(context.Canceled)
	h.paginationInterrupterLock.Lock()
	if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating {
		h.paginationInterrupterLock.Unlock()
		return nil, ErrPaginationAlreadyInProgress
	}
	h.paginationInterrupter[roomID] = cancel
	h.paginationInterrupterLock.Unlock()
	defer func() {
		h.paginationInterrupterLock.Lock()
		delete(h.paginationInterrupter, roomID)
		h.paginationInterrupterLock.Unlock()
	}()

	room, err := h.DB.Room.Get(ctx, roomID)
	if err != nil {
		return nil, fmt.Errorf("failed to get room from database: %w", err)
	} else if room == nil {
		return nil, errors.New("room not found")
	} else if room.PrevBatch == database.PrevBatchPaginationComplete {
		return &PaginationResponse{Events: []*database.Event{}, HasMore: false}, nil
	}
	resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to get messages from server: %w", err)
	}
	// An empty end token means there is nothing further back.
	if resp.End == "" {
		resp.End = database.PrevBatchPaginationComplete
	}
	hasMore := resp.End != database.PrevBatchPaginationComplete
	events := make([]*database.Event, len(resp.Chunk))
	if len(resp.Chunk) == 0 {
		err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
		if err != nil {
			return nil, fmt.Errorf("failed to set prev_batch: %w", err)
		}
		return &PaginationResponse{Events: events, HasMore: hasMore}, nil
	}
	wakeupSessionRequests := false
	err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
		if err := ctx.Err(); err != nil {
			return err
		}
		// Rebuilt from scratch inside the callback so the result is consistent
		// even if the callback were to run more than once.
		events = events[:0]
		eventRowIDs := make([]database.EventRowID, 0, len(resp.Chunk))
		decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
		for _, evt := range resp.Chunk {
			dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true)
			if err != nil {
				return err
			}
			exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID)
			if err != nil {
				return fmt.Errorf("failed to check if event exists in timeline: %w", err)
			} else if exists {
				zerolog.Ctx(ctx).Warn().
					Int64("row_id", int64(dbEvt.RowID)).
					Str("event_id", dbEvt.ID.String()).
					Msg("Event already exists in timeline, skipping")
				continue
			}
			events = append(events, dbEvt)
			eventRowIDs = append(eventRowIDs, dbEvt.RowID)
		}
		wakeupSessionRequests = len(decryptionQueue) > 0
		for _, entry := range decryptionQueue {
			if err := h.DB.SessionRequest.Put(ctx, entry); err != nil {
				return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
			}
		}
		// Advance prev_batch even if every event in the chunk was already in
		// the timeline. Otherwise the next pagination would request the exact
		// same chunk again and never make progress.
		if err := h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End); err != nil {
			return fmt.Errorf("failed to set prev_batch: %w", err)
		}
		if len(events) == 0 {
			return nil
		}
		if err := h.DB.Event.FillReactionCounts(ctx, roomID, events); err != nil {
			return fmt.Errorf("failed to fill reaction counts: %w", err)
		}
		if err := h.DB.Event.FillLastEditRowIDs(ctx, roomID, events); err != nil {
			return fmt.Errorf("failed to fill last edit row IDs: %w", err)
		}
		tuples, err := h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs)
		if err != nil {
			return fmt.Errorf("failed to prepend events to timeline: %w", err)
		}
		for i, evt := range events {
			evt.TimelineRowID = tuples[i].Timeline
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	if wakeupSessionRequests {
		h.WakeupRequestQueue()
	}
	return &PaginationResponse{Events: events, HasMore: hasMore}, nil
}