gomuks/pkg/hicli/paginate.go

245 lines
7.8 KiB
Go

// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"errors"
"fmt"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
)
var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress")
func (h *HiClient) GetEventsByRowIDs(ctx context.Context, rowIDs []database.EventRowID) ([]*database.Event, error) {
events, err := h.DB.Event.GetByRowIDs(ctx, rowIDs...)
if err != nil {
return nil, err
} else if len(events) == 0 {
return events, nil
}
firstRoomID := events[0].RoomID
allInSameRoom := true
for _, evt := range events {
h.ReprocessExistingEvent(ctx, evt)
if evt.RoomID != firstRoomID {
allInSameRoom = false
break
}
}
if allInSameRoom {
err = h.DB.Event.FillLastEditRowIDs(ctx, firstRoomID, events)
if err != nil {
return events, fmt.Errorf("failed to fill last edit row IDs: %w", err)
}
err = h.DB.Event.FillReactionCounts(ctx, firstRoomID, events)
if err != nil {
return events, fmt.Errorf("failed to fill reaction counts: %w", err)
}
} else {
// TODO slow path where events are collected and filling is done one room at a time?
}
return events, nil
}
func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) {
if evt, err := h.DB.Event.GetByID(ctx, eventID); err != nil {
return nil, fmt.Errorf("failed to get event from database: %w", err)
} else if evt != nil {
h.ReprocessExistingEvent(ctx, evt)
return evt, nil
} else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil {
return nil, fmt.Errorf("failed to get event from server: %w", err)
} else {
return h.processEvent(ctx, serverEvt, nil, nil, false)
}
}
func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) {
var evts []*event.Event
if refetch {
resp, err := h.Client.StateAsArray(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to refetch state: %w", err)
}
evts = resp
} else if fetchMembers {
resp, err := h.Client.Members(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to fetch members: %w", err)
}
evts = resp.Chunk
}
if evts != nil {
err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
room, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to get room from database: %w", err)
}
updatedRoom := &database.Room{
ID: room.ID,
HasMemberList: true,
}
entries := make([]*database.CurrentStateEntry, len(evts))
for i, evt := range evts {
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false)
if err != nil {
return fmt.Errorf("failed to process event %s: %w", evt.ID, err)
}
entries[i] = &database.CurrentStateEntry{
EventType: evt.Type,
StateKey: *evt.StateKey,
EventRowID: dbEvt.RowID,
}
if evt.Type == event.StateMember {
entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string))
} else {
processImportantEvent(ctx, evt, room, updatedRoom)
}
}
err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries)
if err != nil {
return err
}
roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
if roomChanged {
err = h.DB.Room.Upsert(ctx, updatedRoom)
if err != nil {
return fmt.Errorf("failed to save room data: %w", err)
}
}
return nil
})
if err != nil {
return nil, err
}
}
return h.DB.CurrentState.GetAll(ctx, roomID)
}
type PaginationResponse struct {
Events []*database.Event `json:"events"`
HasMore bool `json:"has_more"`
}
func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) (*PaginationResponse, error) {
evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID)
if err != nil {
return nil, err
} else if len(evts) > 0 {
for _, evt := range evts {
h.ReprocessExistingEvent(ctx, evt)
}
return &PaginationResponse{Events: evts, HasMore: true}, nil
} else {
return h.PaginateServer(ctx, roomID, limit)
}
}
func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) (*PaginationResponse, error) {
ctx, cancel := context.WithCancelCause(ctx)
h.paginationInterrupterLock.Lock()
if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating {
h.paginationInterrupterLock.Unlock()
return nil, ErrPaginationAlreadyInProgress
}
h.paginationInterrupter[roomID] = cancel
h.paginationInterrupterLock.Unlock()
defer func() {
h.paginationInterrupterLock.Lock()
delete(h.paginationInterrupter, roomID)
h.paginationInterrupterLock.Unlock()
}()
room, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to get room from database: %w", err)
} else if room.PrevBatch == database.PrevBatchPaginationComplete {
return &PaginationResponse{Events: []*database.Event{}, HasMore: false}, nil
}
resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit)
if err != nil {
return nil, fmt.Errorf("failed to get messages from server: %w", err)
}
events := make([]*database.Event, len(resp.Chunk))
if resp.End == "" {
resp.End = database.PrevBatchPaginationComplete
}
if resp.End == database.PrevBatchPaginationComplete || len(resp.Chunk) == 0 {
err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
if err != nil {
return nil, fmt.Errorf("failed to set prev_batch: %w", err)
}
return &PaginationResponse{Events: events, HasMore: resp.End != ""}, nil
}
wakeupSessionRequests := false
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
if err = ctx.Err(); err != nil {
return err
}
eventRowIDs := make([]database.EventRowID, len(resp.Chunk))
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
iOffset := 0
for i, evt := range resp.Chunk {
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true)
if err != nil {
return err
} else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil {
return fmt.Errorf("failed to check if event exists in timeline: %w", err)
} else if exists {
zerolog.Ctx(ctx).Warn().
Int64("row_id", int64(dbEvt.RowID)).
Str("event_id", dbEvt.ID.String()).
Msg("Event already exists in timeline, skipping")
iOffset++
continue
}
events[i-iOffset] = dbEvt
eventRowIDs[i-iOffset] = events[i-iOffset].RowID
}
if iOffset >= len(events) {
events = events[:0]
return nil
}
events = events[:len(events)-iOffset]
eventRowIDs = eventRowIDs[:len(eventRowIDs)-iOffset]
wakeupSessionRequests = len(decryptionQueue) > 0
for _, entry := range decryptionQueue {
err = h.DB.SessionRequest.Put(ctx, entry)
if err != nil {
return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
}
}
err = h.DB.Event.FillLastEditRowIDs(ctx, roomID, events)
if err != nil {
return fmt.Errorf("failed to fill last edit row IDs: %w", err)
}
err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
if err != nil {
return fmt.Errorf("failed to set prev_batch: %w", err)
}
var tuples []database.TimelineRowTuple
tuples, err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs)
if err != nil {
return fmt.Errorf("failed to prepend events to timeline: %w", err)
}
for i, evt := range events {
evt.TimelineRowID = tuples[i].Timeline
}
return nil
})
if err == nil && wakeupSessionRequests {
h.WakeupRequestQueue()
}
return &PaginationResponse{Events: events, HasMore: true}, err
}