diff --git a/go.mod b/go.mod index 51aba19..1b12039 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.2-0.20241030110711-b3e597e16b74 + go.mau.fi/util v0.8.2-0.20241112213434-d05f63473223 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.28.0 golang.org/x/image v0.21.0 diff --git a/go.sum b/go.sum index f259946..186e7a2 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.2-0.20241030110711-b3e597e16b74 h1:hzVVXFEIQWefBlokVlQ2nr7EzRnMdMLF+K+kqWsm6OE= -go.mau.fi/util v0.8.2-0.20241030110711-b3e597e16b74/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc= +go.mau.fi/util v0.8.2-0.20241112213434-d05f63473223 h1:XvAisCJ+cjg2wGndop+KrQMwAoZCjoa52J40hlHNHR0= +go.mau.fi/util v0.8.2-0.20241112213434-d05f63473223/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= diff --git a/pkg/hicli/database/event.go b/pkg/hicli/database/event.go index 8a5b09b..f7149dd 100644 --- a/pkg/hicli/database/event.go +++ b/pkg/hicli/database/event.go @@ -11,6 +11,7 @@ import ( "database/sql" "encoding/json" "fmt" + "slices" "strings" "time" @@ -127,6 +128,39 @@ func (eq *EventQuery) Insert(ctx context.Context, evt *Event) (rowID EventRowID, return } +var stateEventMassInserter = dbutil.NewMassInsertBuilder[*Event, [1]any]( + strings.ReplaceAll(upsertEventQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)", "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"), + "($1, $%d, $%d, $%d, $%d, $%d, $%d, NULL, NULL, $%d, NULL, $%d, $%d, NULL, NULL, NULL, NULL, NULL, '{}', 0, 0)", +) + +var massInsertConverter = dbutil.ConvertRowFn[EventRowID](dbutil.ScanSingleColumn[EventRowID]) + +func (e *Event) GetMassInsertValues() [9]any { + return [9]any{ + e.ID, e.Sender, e.Type, e.StateKey, e.Timestamp.UnixMilli(), + unsafeJSONString(e.Content), unsafeJSONString(e.Unsigned), + e.TransactionID, e.RedactedBy, + } +} + +func (eq *EventQuery) MassUpsertState(ctx context.Context, evts []*Event) error { + for chunk := range slices.Chunk(evts, 500) { + query, params := stateEventMassInserter.Build([1]any{chunk[0].RoomID}, chunk) + i := 0 + err := massInsertConverter. + NewRowIter(eq.GetDB().Query(ctx, query, params...)). + Iter(func(t EventRowID) (bool, error) { + chunk[i].RowID = t + i++ + return true, nil + }) + if err != nil { + return err + } + } + return nil +} + func (eq *EventQuery) UpdateID(ctx context.Context, rowID EventRowID, newID id.EventID) error { return eq.Exec(ctx, updateEventIDQuery, rowID, newID) } @@ -339,6 +373,7 @@ func MautrixToEvent(evt *event.Event) *Event { Content: evt.Content.VeryRaw, MegolmSessionID: getMegolmSessionID(evt), TransactionID: evt.Unsigned.TransactionID, + Reactions: make(map[string]int), } if !strings.HasPrefix(dbEvt.TransactionID, "hicli-mautrix-go_") { dbEvt.TransactionID = "" diff --git a/pkg/hicli/database/media.go b/pkg/hicli/database/media.go index d060ca7..edc6698 100644 --- a/pkg/hicli/database/media.go +++ b/pkg/hicli/database/media.go @@ -51,6 +51,14 @@ const ( ` ) +var mediaReferenceMassInserter = dbutil.NewMassInsertBuilder[*MediaReference, [0]any]( + addMediaReferenceQuery, "($%d, $%d)", +) + +var mediaMassInserter = dbutil.NewMassInsertBuilder[*PlainMedia, [0]any]( + "INSERT INTO media (mxc) VALUES ($1) ON CONFLICT (mxc) DO NOTHING", "($%d)", +) + type MediaQuery struct { *dbutil.QueryHelper[*Media] } @@ -63,6 +71,28 @@ func (mq *MediaQuery) AddReference(ctx context.Context, evtRowID EventRowID, mxc return mq.Exec(ctx, addMediaReferenceQuery, evtRowID, &mxc) } +func (mq *MediaQuery) AddMany(ctx context.Context, medias []*PlainMedia) error { + for chunk := range slices.Chunk(medias, 8000) { + query, params := mediaMassInserter.Build([0]any{}, chunk) + err := mq.Exec(ctx, query, params...) + if err != nil { + return err + } + } + return nil +} + +func (mq *MediaQuery) AddManyReferences(ctx context.Context, refs []*MediaReference) error { + for chunk := range slices.Chunk(refs, 4000) { + query, params := mediaReferenceMassInserter.Build([0]any{}, chunk) + err := mq.Exec(ctx, query, params...) + if err != nil { + return err + } + } + return nil +} + func (mq *MediaQuery) Put(ctx context.Context, cm *Media) error { return mq.Exec(ctx, upsertMediaQuery, cm.sqlVariables()...) } @@ -164,3 +194,18 @@ func (m *Media) ContentDisposition() string { } return "attachment" } + +type MediaReference struct { + EventRowID EventRowID + MediaMXC id.ContentURI +} + +func (mr *MediaReference) GetMassInsertValues() [2]any { + return [2]any{mr.EventRowID, &mr.MediaMXC} +} + +type PlainMedia id.ContentURI + +func (pm *PlainMedia) GetMassInsertValues() [1]any { + return [1]any{(*id.ContentURI)(pm)} +} diff --git a/pkg/hicli/paginate.go b/pkg/hicli/paginate.go index bb412ed..e492160 100644 --- a/pkg/hicli/paginate.go +++ b/pkg/hicli/paginate.go @@ -10,6 +10,7 @@ import ( "context" "errors" "fmt" + "slices" "github.com/rs/zerolog" "maunium.net/go/mautrix" @@ -83,6 +84,30 @@ func (h *HiClient) processGetRoomState(ctx context.Context, roomID id.RoomID, fe if evts == nil { return nil } + dbEvts := make([]*database.Event, len(evts)) + currentStateEntries := make([]*database.CurrentStateEntry, len(evts)) + mediaReferenceEntries := make([]*database.MediaReference, len(evts)) + mediaCacheEntries := make([]*database.PlainMedia, 0, len(evts)) + for i, evt := range evts { + dbEvts[i] = database.MautrixToEvent(evt) + currentStateEntries[i] = &database.CurrentStateEntry{ + EventType: evt.Type, + StateKey: *evt.StateKey, + } + var mediaURL string + if evt.Type == event.StateMember { + currentStateEntries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string)) + mediaURL, _ = evt.Content.Raw["avatar_url"].(string) + } else if evt.Type == event.StateRoomAvatar { + mediaURL, _ = evt.Content.Raw["url"].(string) + } + if mxc := id.ContentURIString(mediaURL).ParseOrIgnore(); mxc.IsValid() { + mediaCacheEntries = append(mediaCacheEntries, (*database.PlainMedia)(&mxc)) + mediaReferenceEntries[i] = &database.MediaReference{ + MediaMXC: mxc, + } + } + } return h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { room, err := h.DB.Room.Get(ctx, roomID) if err != nil { @@ -92,26 +117,33 @@ func (h *HiClient) processGetRoomState(ctx context.Context, roomID id.RoomID, fe ID: room.ID, HasMemberList: true, } - entries := make([]*database.CurrentStateEntry, len(evts)) - for i, evt := range evts { - dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false) - if err != nil { - return fmt.Errorf("failed to process event %s: %w", evt.ID, err) + err = h.DB.Event.MassUpsertState(ctx, dbEvts) + if err != nil { + return fmt.Errorf("failed to save events: %w", err) + } + for i := range currentStateEntries { + currentStateEntries[i].EventRowID = dbEvts[i].RowID + if mediaReferenceEntries[i] != nil { + mediaReferenceEntries[i].EventRowID = dbEvts[i].RowID } - entries[i] = &database.CurrentStateEntry{ - EventType: evt.Type, - StateKey: *evt.StateKey, - EventRowID: dbEvt.RowID, - } - if evt.Type == event.StateMember { - entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string)) - } else { - processImportantEvent(ctx, evt, room, updatedRoom) + if evts[i].Type != event.StateMember { + processImportantEvent(ctx, evts[i], room, updatedRoom) } } - err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries) + err = h.DB.Media.AddMany(ctx, mediaCacheEntries) if err != nil { - return err + return fmt.Errorf("failed to save media cache entries: %w", err) + } + mediaReferenceEntries = slices.DeleteFunc(mediaReferenceEntries, func(reference *database.MediaReference) bool { + return reference == nil + }) + err = h.DB.Media.AddManyReferences(ctx, mediaReferenceEntries) + if err != nil { + return fmt.Errorf("failed to save media reference entries: %w", err) + } + err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, currentStateEntries) + if err != nil { + return fmt.Errorf("failed to save current state entries: %w", err) } roomChanged := updatedRoom.CheckChangesAndCopyInto(room) if roomChanged { diff --git a/pkg/hicli/sync.go b/pkg/hicli/sync.go index 15d0852..8c566c9 100644 --- a/pkg/hicli/sync.go +++ b/pkg/hicli/sync.go @@ -473,7 +473,6 @@ func (h *HiClient) processEvent( } } dbEvt := database.MautrixToEvent(evt) - dbEvt.Reactions = make(map[string]int) contentWithoutFallback := removeReplyFallback(evt) if contentWithoutFallback != nil { dbEvt.Content = contentWithoutFallback