// Mirror of https://github.com/tulir/gomuks.git (synced 2025-04-20 10:33:41 -05:00); 245 lines, 7.8 KiB, Go.
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package hicli
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/rs/zerolog"
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/id"
|
|
|
|
"go.mau.fi/gomuks/pkg/hicli/database"
|
|
)
|
|
|
|
var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress")
|
|
|
|
func (h *HiClient) GetEventsByRowIDs(ctx context.Context, rowIDs []database.EventRowID) ([]*database.Event, error) {
|
|
events, err := h.DB.Event.GetByRowIDs(ctx, rowIDs...)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if len(events) == 0 {
|
|
return events, nil
|
|
}
|
|
firstRoomID := events[0].RoomID
|
|
allInSameRoom := true
|
|
for _, evt := range events {
|
|
h.ReprocessExistingEvent(ctx, evt)
|
|
if evt.RoomID != firstRoomID {
|
|
allInSameRoom = false
|
|
break
|
|
}
|
|
}
|
|
if allInSameRoom {
|
|
err = h.DB.Event.FillLastEditRowIDs(ctx, firstRoomID, events)
|
|
if err != nil {
|
|
return events, fmt.Errorf("failed to fill last edit row IDs: %w", err)
|
|
}
|
|
err = h.DB.Event.FillReactionCounts(ctx, firstRoomID, events)
|
|
if err != nil {
|
|
return events, fmt.Errorf("failed to fill reaction counts: %w", err)
|
|
}
|
|
} else {
|
|
// TODO slow path where events are collected and filling is done one room at a time?
|
|
}
|
|
return events, nil
|
|
}
|
|
|
|
func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) {
|
|
if evt, err := h.DB.Event.GetByID(ctx, eventID); err != nil {
|
|
return nil, fmt.Errorf("failed to get event from database: %w", err)
|
|
} else if evt != nil {
|
|
h.ReprocessExistingEvent(ctx, evt)
|
|
return evt, nil
|
|
} else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil {
|
|
return nil, fmt.Errorf("failed to get event from server: %w", err)
|
|
} else {
|
|
return h.processEvent(ctx, serverEvt, nil, nil, false)
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) {
|
|
var evts []*event.Event
|
|
if refetch {
|
|
resp, err := h.Client.StateAsArray(ctx, roomID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to refetch state: %w", err)
|
|
}
|
|
evts = resp
|
|
} else if fetchMembers {
|
|
resp, err := h.Client.Members(ctx, roomID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch members: %w", err)
|
|
}
|
|
evts = resp.Chunk
|
|
}
|
|
if evts != nil {
|
|
err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
|
room, err := h.DB.Room.Get(ctx, roomID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get room from database: %w", err)
|
|
}
|
|
updatedRoom := &database.Room{
|
|
ID: room.ID,
|
|
HasMemberList: true,
|
|
}
|
|
entries := make([]*database.CurrentStateEntry, len(evts))
|
|
for i, evt := range evts {
|
|
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process event %s: %w", evt.ID, err)
|
|
}
|
|
entries[i] = &database.CurrentStateEntry{
|
|
EventType: evt.Type,
|
|
StateKey: *evt.StateKey,
|
|
EventRowID: dbEvt.RowID,
|
|
}
|
|
if evt.Type == event.StateMember {
|
|
entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string))
|
|
} else {
|
|
processImportantEvent(ctx, evt, room, updatedRoom)
|
|
}
|
|
}
|
|
err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
|
|
if roomChanged {
|
|
err = h.DB.Room.Upsert(ctx, updatedRoom)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save room data: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return h.DB.CurrentState.GetAll(ctx, roomID)
|
|
}
|
|
|
|
type PaginationResponse struct {
|
|
Events []*database.Event `json:"events"`
|
|
HasMore bool `json:"has_more"`
|
|
}
|
|
|
|
func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) (*PaginationResponse, error) {
|
|
evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if len(evts) > 0 {
|
|
for _, evt := range evts {
|
|
h.ReprocessExistingEvent(ctx, evt)
|
|
}
|
|
return &PaginationResponse{Events: evts, HasMore: true}, nil
|
|
} else {
|
|
return h.PaginateServer(ctx, roomID, limit)
|
|
}
|
|
}
|
|
|
|
func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) (*PaginationResponse, error) {
|
|
ctx, cancel := context.WithCancelCause(ctx)
|
|
h.paginationInterrupterLock.Lock()
|
|
if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating {
|
|
h.paginationInterrupterLock.Unlock()
|
|
return nil, ErrPaginationAlreadyInProgress
|
|
}
|
|
h.paginationInterrupter[roomID] = cancel
|
|
h.paginationInterrupterLock.Unlock()
|
|
defer func() {
|
|
h.paginationInterrupterLock.Lock()
|
|
delete(h.paginationInterrupter, roomID)
|
|
h.paginationInterrupterLock.Unlock()
|
|
}()
|
|
|
|
room, err := h.DB.Room.Get(ctx, roomID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get room from database: %w", err)
|
|
} else if room.PrevBatch == database.PrevBatchPaginationComplete {
|
|
return &PaginationResponse{Events: []*database.Event{}, HasMore: false}, nil
|
|
}
|
|
resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get messages from server: %w", err)
|
|
}
|
|
events := make([]*database.Event, len(resp.Chunk))
|
|
if resp.End == "" {
|
|
resp.End = database.PrevBatchPaginationComplete
|
|
}
|
|
if resp.End == database.PrevBatchPaginationComplete || len(resp.Chunk) == 0 {
|
|
err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to set prev_batch: %w", err)
|
|
}
|
|
return &PaginationResponse{Events: events, HasMore: resp.End != ""}, nil
|
|
}
|
|
wakeupSessionRequests := false
|
|
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
|
if err = ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
eventRowIDs := make([]database.EventRowID, len(resp.Chunk))
|
|
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
|
|
iOffset := 0
|
|
for i, evt := range resp.Chunk {
|
|
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true)
|
|
if err != nil {
|
|
return err
|
|
} else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil {
|
|
return fmt.Errorf("failed to check if event exists in timeline: %w", err)
|
|
} else if exists {
|
|
zerolog.Ctx(ctx).Warn().
|
|
Int64("row_id", int64(dbEvt.RowID)).
|
|
Str("event_id", dbEvt.ID.String()).
|
|
Msg("Event already exists in timeline, skipping")
|
|
iOffset++
|
|
continue
|
|
}
|
|
events[i-iOffset] = dbEvt
|
|
eventRowIDs[i-iOffset] = events[i-iOffset].RowID
|
|
}
|
|
if iOffset >= len(events) {
|
|
events = events[:0]
|
|
return nil
|
|
}
|
|
events = events[:len(events)-iOffset]
|
|
eventRowIDs = eventRowIDs[:len(eventRowIDs)-iOffset]
|
|
wakeupSessionRequests = len(decryptionQueue) > 0
|
|
for _, entry := range decryptionQueue {
|
|
err = h.DB.SessionRequest.Put(ctx, entry)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
|
|
}
|
|
}
|
|
err = h.DB.Event.FillLastEditRowIDs(ctx, roomID, events)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fill last edit row IDs: %w", err)
|
|
}
|
|
err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to set prev_batch: %w", err)
|
|
}
|
|
var tuples []database.TimelineRowTuple
|
|
tuples, err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepend events to timeline: %w", err)
|
|
}
|
|
for i, evt := range events {
|
|
evt.TimelineRowID = tuples[i].Timeline
|
|
}
|
|
return nil
|
|
})
|
|
if err == nil && wakeupSessionRequests {
|
|
h.WakeupRequestQueue()
|
|
}
|
|
return &PaginationResponse{Events: events, HasMore: true}, err
|
|
}
|