mirror of
https://github.com/tulir/gomuks.git
synced 2025-04-20 10:33:41 -05:00

Edits will now use a different HTML -> markdown converter than what is used to generate the body. This allows the plaintext body to have a plain shortcode for custom emojis, while still having the raw data for edits. Additionally, for sent events, the raw input is saved locally, which allows preserving commands and other such things. A future extension may store the raw input in a custom field in the Matrix event to allow lossless edits of messages sent from other clients.
507 lines
16 KiB
Go
507 lines
16 KiB
Go
// Copyright (c) 2024 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/tidwall/gjson"
|
|
"go.mau.fi/util/dbutil"
|
|
"go.mau.fi/util/exgjson"
|
|
"go.mau.fi/util/jsontime"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
const (
|
|
getEventBaseQuery = `
|
|
SELECT rowid, -1,
|
|
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
|
|
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
|
|
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
|
|
FROM event
|
|
`
|
|
getEventByRowID = getEventBaseQuery + `WHERE rowid = $1`
|
|
getManyEventsByRowID = getEventBaseQuery + `WHERE rowid IN (%s)`
|
|
getEventByID = getEventBaseQuery + `WHERE event_id = $1`
|
|
getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL`
|
|
insertEventBaseQuery = `
|
|
INSERT INTO event (
|
|
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
|
|
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
|
|
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
|
|
)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
|
|
`
|
|
insertEventQuery = insertEventBaseQuery + `RETURNING rowid`
|
|
upsertEventQuery = insertEventBaseQuery + `
|
|
ON CONFLICT (event_id) DO UPDATE
|
|
SET decrypted=COALESCE(event.decrypted, excluded.decrypted),
|
|
decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type),
|
|
redacted_by=COALESCE(event.redacted_by, excluded.redacted_by),
|
|
decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END,
|
|
send_error=excluded.send_error,
|
|
timestamp=excluded.timestamp,
|
|
unsigned=COALESCE(excluded.unsigned, event.unsigned),
|
|
local_content=COALESCE(excluded.local_content, event.local_content)
|
|
ON CONFLICT (transaction_id) DO UPDATE
|
|
SET event_id=excluded.event_id,
|
|
timestamp=excluded.timestamp,
|
|
unsigned=excluded.unsigned
|
|
RETURNING rowid
|
|
`
|
|
updateEventSendErrorQuery = `UPDATE event SET send_error = $2 WHERE rowid = $1`
|
|
updateEventIDQuery = `UPDATE event SET event_id = $2, send_error = NULL WHERE rowid=$1`
|
|
updateEventDecryptedQuery = `UPDATE event SET decrypted = $2, decrypted_type = $3, decryption_error = NULL, unread_type = $4, local_content = $5 WHERE rowid = $1`
|
|
updateEventLocalContentQuery = `UPDATE event SET local_content = $2 WHERE rowid = $1`
|
|
updateEventEncryptedContentQuery = `UPDATE event SET content = $2, megolm_session_id = $3 WHERE rowid = $1`
|
|
getEventReactionsQuery = getEventBaseQuery + `
|
|
WHERE room_id = ?
|
|
AND type = 'm.reaction'
|
|
AND relation_type = 'm.annotation'
|
|
AND redacted_by IS NULL
|
|
AND relates_to IN (%s)
|
|
`
|
|
getEventEditRowIDsQuery = `
|
|
SELECT main.event_id, edit.rowid
|
|
FROM event main
|
|
JOIN event edit ON
|
|
edit.room_id = main.room_id
|
|
AND edit.relates_to = main.event_id
|
|
AND edit.relation_type = 'm.replace'
|
|
AND edit.type = main.type
|
|
AND edit.sender = main.sender
|
|
AND edit.redacted_by IS NULL
|
|
WHERE main.event_id IN (%s)
|
|
ORDER BY main.event_id, edit.timestamp
|
|
`
|
|
setLastEditRowIDQuery = `
|
|
UPDATE event SET last_edit_rowid = $2 WHERE event_id = $1
|
|
`
|
|
updateReactionCountsQuery = `UPDATE event SET reactions = $2 WHERE event_id = $1`
|
|
)
|
|
|
|
type EventQuery struct {
|
|
*dbutil.QueryHelper[*Event]
|
|
}
|
|
|
|
func (eq *EventQuery) GetFailedByMegolmSessionID(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) ([]*Event, error) {
|
|
return eq.QueryMany(ctx, getFailedEventsByMegolmSessionID, roomID, sessionID)
|
|
}
|
|
|
|
func (eq *EventQuery) GetByID(ctx context.Context, eventID id.EventID) (*Event, error) {
|
|
return eq.QueryOne(ctx, getEventByID, eventID)
|
|
}
|
|
|
|
func (eq *EventQuery) GetByRowID(ctx context.Context, rowID EventRowID) (*Event, error) {
|
|
return eq.QueryOne(ctx, getEventByRowID, rowID)
|
|
}
|
|
|
|
func (eq *EventQuery) GetByRowIDs(ctx context.Context, rowIDs ...EventRowID) ([]*Event, error) {
|
|
query, params := buildMultiEventGetFunction(nil, rowIDs, getManyEventsByRowID)
|
|
return eq.QueryMany(ctx, query, params...)
|
|
}
|
|
|
|
func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID EventRowID, err error) {
|
|
err = eq.GetDB().QueryRow(ctx, upsertEventQuery, evt.sqlVariables()...).Scan(&rowID)
|
|
if err == nil {
|
|
evt.RowID = rowID
|
|
}
|
|
return
|
|
}
|
|
|
|
func (eq *EventQuery) Insert(ctx context.Context, evt *Event) (rowID EventRowID, err error) {
|
|
err = eq.GetDB().QueryRow(ctx, insertEventQuery, evt.sqlVariables()...).Scan(&rowID)
|
|
if err == nil {
|
|
evt.RowID = rowID
|
|
}
|
|
return
|
|
}
|
|
|
|
func (eq *EventQuery) UpdateID(ctx context.Context, rowID EventRowID, newID id.EventID) error {
|
|
return eq.Exec(ctx, updateEventIDQuery, rowID, newID)
|
|
}
|
|
|
|
func (eq *EventQuery) UpdateSendError(ctx context.Context, rowID EventRowID, sendError string) error {
|
|
return eq.Exec(ctx, updateEventSendErrorQuery, rowID, sendError)
|
|
}
|
|
|
|
func (eq *EventQuery) UpdateDecrypted(ctx context.Context, evt *Event) error {
|
|
return eq.Exec(
|
|
ctx,
|
|
updateEventDecryptedQuery,
|
|
evt.RowID,
|
|
unsafeJSONString(evt.Decrypted),
|
|
evt.DecryptedType,
|
|
evt.UnreadType,
|
|
dbutil.JSONPtr(evt.LocalContent),
|
|
)
|
|
}
|
|
|
|
func (eq *EventQuery) UpdateLocalContent(ctx context.Context, evt *Event) error {
|
|
return eq.Exec(ctx, updateEventLocalContentQuery, evt.RowID, dbutil.JSONPtr(evt.LocalContent))
|
|
}
|
|
|
|
func (eq *EventQuery) UpdateEncryptedContent(ctx context.Context, evt *Event) error {
|
|
return eq.Exec(ctx, updateEventEncryptedContentQuery, evt.RowID, unsafeJSONString(evt.Content), evt.MegolmSessionID)
|
|
}
|
|
|
|
func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error {
|
|
eventIDs := make([]id.EventID, 0, len(events))
|
|
eventMap := make(map[id.EventID]*Event)
|
|
for _, evt := range events {
|
|
if evt.Reactions == nil {
|
|
eventIDs = append(eventIDs, evt.ID)
|
|
eventMap[evt.ID] = evt
|
|
}
|
|
}
|
|
if len(eventIDs) == 0 {
|
|
return nil
|
|
}
|
|
result, err := eq.GetReactions(ctx, roomID, eventIDs...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for evtID, res := range result {
|
|
eventMap[evtID].Reactions = res.Counts
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (eq *EventQuery) FillLastEditRowIDs(ctx context.Context, roomID id.RoomID, events []*Event) error {
|
|
eventIDs := make([]id.EventID, len(events))
|
|
eventMap := make(map[id.EventID]*Event)
|
|
for i, evt := range events {
|
|
if evt.LastEditRowID == nil {
|
|
eventIDs[i] = evt.ID
|
|
eventMap[evt.ID] = evt
|
|
}
|
|
}
|
|
return eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
|
|
result, err := eq.GetEditRowIDs(ctx, roomID, eventIDs...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for evtID, res := range result {
|
|
lastEditRowID := res[len(res)-1]
|
|
eventMap[evtID].LastEditRowID = &lastEditRowID
|
|
delete(eventMap, evtID)
|
|
err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, lastEditRowID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
var zero EventRowID
|
|
for evtID, evt := range eventMap {
|
|
evt.LastEditRowID = &zero
|
|
err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, zero)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// reactionKeyPath is the gjson path to the "key" field inside "m.relates_to";
// GetReactions reads it from reaction event content to count reactions.
var reactionKeyPath = exgjson.Path("m.relates_to", "key")
|
|
|
|
type GetReactionsResult struct {
|
|
Events []*Event
|
|
Counts map[string]int
|
|
}
|
|
|
|
// buildMultiEventGetFunction expands a query template for a variable number
// of IDs.
//
// query must contain exactly one %s verb, which is replaced with one "?"
// placeholder per element of eventIDs, separated by commas. The returned
// parameter slice holds preParams first, followed by the IDs in order, so
// preParams must line up with any placeholders that precede the %s in query.
//
// An empty eventIDs slice yields an empty placeholder list (e.g. "IN ()")
// instead of panicking.
func buildMultiEventGetFunction[T any](preParams []any, eventIDs []T, query string) (string, []any) {
	params := make([]any, 0, len(preParams)+len(eventIDs))
	params = append(params, preParams...)
	for _, evtID := range eventIDs {
		params = append(params, evtID)
	}
	// TrimSuffix rather than slicing off the last byte: with zero IDs the
	// repeated string is empty and [:len-1] would be out of range.
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(eventIDs)), ",")
	return fmt.Sprintf(query, placeholders), params
}
|
|
|
|
type editRowIDTuple struct {
|
|
eventID id.EventID
|
|
editRowID EventRowID
|
|
}
|
|
|
|
func (eq *EventQuery) GetEditRowIDs(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID][]EventRowID, error) {
|
|
query, params := buildMultiEventGetFunction([]any{roomID}, eventIDs, getEventEditRowIDsQuery)
|
|
rows, err := eq.GetDB().Query(ctx, query, params...)
|
|
output := make(map[id.EventID][]EventRowID)
|
|
return output, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (tuple editRowIDTuple, err error) {
|
|
err = row.Scan(&tuple.eventID, &tuple.editRowID)
|
|
return
|
|
}, err).Iter(func(tuple editRowIDTuple) (bool, error) {
|
|
output[tuple.eventID] = append(output[tuple.eventID], tuple.editRowID)
|
|
return true, nil
|
|
})
|
|
}
|
|
|
|
func (eq *EventQuery) GetReactions(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID]*GetReactionsResult, error) {
|
|
result := make(map[id.EventID]*GetReactionsResult, len(eventIDs))
|
|
for _, evtID := range eventIDs {
|
|
result[evtID] = &GetReactionsResult{Counts: make(map[string]int)}
|
|
}
|
|
return result, eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
|
|
query, params := buildMultiEventGetFunction([]any{roomID}, eventIDs, getEventReactionsQuery)
|
|
events, err := eq.QueryMany(ctx, query, params...)
|
|
if err != nil {
|
|
return err
|
|
} else if len(events) == 0 {
|
|
return nil
|
|
}
|
|
for _, evt := range events {
|
|
dest := result[evt.RelatesTo]
|
|
dest.Events = append(dest.Events, evt)
|
|
keyRes := gjson.GetBytes(evt.Content, reactionKeyPath)
|
|
if keyRes.Type == gjson.String {
|
|
dest.Counts[keyRes.Str]++
|
|
}
|
|
}
|
|
for evtID, res := range result {
|
|
if len(res.Counts) > 0 {
|
|
err = eq.Exec(ctx, updateReactionCountsQuery, evtID, dbutil.JSON{Data: &res.Counts})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// EventRowID is the row ID of an entry in the event table.
type EventRowID int64
|
|
|
|
func (m EventRowID) GetMassInsertValues() [1]any {
|
|
return [1]any{m}
|
|
}
|
|
|
|
// LocalContent is extra per-event data stored as JSON in the local_content
// column of the event table. Every field is omitted from the JSON when it has
// its zero value.
type LocalContent struct {
	SanitizedHTML string `json:"sanitized_html,omitempty"`
	HTMLVersion   int    `json:"html_version,omitempty"`
	WasPlaintext  bool   `json:"was_plaintext,omitempty"`
	BigEmoji      bool   `json:"big_emoji,omitempty"`
	HasMath       bool   `json:"has_math,omitempty"`
	// NOTE(review): presumably the raw input kept for lossless edits - confirm against callers.
	EditSource string `json:"edit_source,omitempty"`
}
|
|
|
|
type Event struct {
|
|
RowID EventRowID `json:"rowid"`
|
|
TimelineRowID TimelineRowID `json:"timeline_rowid"`
|
|
|
|
RoomID id.RoomID `json:"room_id"`
|
|
ID id.EventID `json:"event_id"`
|
|
Sender id.UserID `json:"sender"`
|
|
Type string `json:"type"`
|
|
StateKey *string `json:"state_key,omitempty"`
|
|
Timestamp jsontime.UnixMilli `json:"timestamp"`
|
|
|
|
Content json.RawMessage `json:"content"`
|
|
Decrypted json.RawMessage `json:"decrypted,omitempty"`
|
|
DecryptedType string `json:"decrypted_type,omitempty"`
|
|
Unsigned json.RawMessage `json:"unsigned,omitempty"`
|
|
LocalContent *LocalContent `json:"local_content,omitempty"`
|
|
|
|
TransactionID string `json:"transaction_id,omitempty"`
|
|
|
|
RedactedBy id.EventID `json:"redacted_by,omitempty"`
|
|
RelatesTo id.EventID `json:"relates_to,omitempty"`
|
|
RelationType event.RelationType `json:"relation_type,omitempty"`
|
|
|
|
MegolmSessionID id.SessionID `json:"-"`
|
|
DecryptionError string `json:"decryption_error,omitempty"`
|
|
SendError string `json:"send_error,omitempty"`
|
|
|
|
Reactions map[string]int `json:"reactions,omitempty"`
|
|
LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"`
|
|
UnreadType UnreadType `json:"unread_type,omitempty"`
|
|
}
|
|
|
|
func MautrixToEvent(evt *event.Event) *Event {
|
|
dbEvt := &Event{
|
|
RoomID: evt.RoomID,
|
|
ID: evt.ID,
|
|
Sender: evt.Sender,
|
|
Type: evt.Type.Type,
|
|
StateKey: evt.StateKey,
|
|
Timestamp: jsontime.UM(time.UnixMilli(evt.Timestamp)),
|
|
Content: evt.Content.VeryRaw,
|
|
MegolmSessionID: getMegolmSessionID(evt),
|
|
TransactionID: evt.Unsigned.TransactionID,
|
|
}
|
|
if !strings.HasPrefix(dbEvt.TransactionID, "hicli-mautrix-go_") {
|
|
dbEvt.TransactionID = ""
|
|
}
|
|
dbEvt.RelatesTo, dbEvt.RelationType = getRelatesToFromEvent(evt)
|
|
dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned)
|
|
if evt.Unsigned.RedactedBecause != nil {
|
|
dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID
|
|
}
|
|
return dbEvt
|
|
}
|
|
|
|
func (e *Event) AsRawMautrix() *event.Event {
|
|
if e == nil {
|
|
return nil
|
|
}
|
|
evt := &event.Event{
|
|
RoomID: e.RoomID,
|
|
ID: e.ID,
|
|
Sender: e.Sender,
|
|
Type: event.Type{Type: e.Type, Class: event.MessageEventType},
|
|
StateKey: e.StateKey,
|
|
Timestamp: e.Timestamp.UnixMilli(),
|
|
Content: event.Content{VeryRaw: e.Content},
|
|
}
|
|
if e.Decrypted != nil {
|
|
evt.Content.VeryRaw = e.Decrypted
|
|
evt.Type.Type = e.DecryptedType
|
|
evt.Mautrix.WasEncrypted = true
|
|
}
|
|
if e.StateKey != nil {
|
|
evt.Type.Class = event.StateEventType
|
|
}
|
|
_ = json.Unmarshal(e.Unsigned, &evt.Unsigned)
|
|
return evt
|
|
}
|
|
|
|
func (e *Event) Scan(row dbutil.Scannable) (*Event, error) {
|
|
var timestamp int64
|
|
var transactionID, redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, sendError, decryptedType sql.NullString
|
|
err := row.Scan(
|
|
&e.RowID,
|
|
&e.TimelineRowID,
|
|
&e.RoomID,
|
|
&e.ID,
|
|
&e.Sender,
|
|
&e.Type,
|
|
&e.StateKey,
|
|
×tamp,
|
|
(*[]byte)(&e.Content),
|
|
(*[]byte)(&e.Decrypted),
|
|
&decryptedType,
|
|
(*[]byte)(&e.Unsigned),
|
|
dbutil.JSON{Data: &e.LocalContent},
|
|
&transactionID,
|
|
&redactedBy,
|
|
&relatesTo,
|
|
&relationType,
|
|
&megolmSessionID,
|
|
&decryptionError,
|
|
&sendError,
|
|
dbutil.JSON{Data: &e.Reactions},
|
|
&e.LastEditRowID,
|
|
&e.UnreadType,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
e.Timestamp = jsontime.UM(time.UnixMilli(timestamp))
|
|
e.TransactionID = transactionID.String
|
|
e.RedactedBy = id.EventID(redactedBy.String)
|
|
e.RelatesTo = id.EventID(relatesTo.String)
|
|
e.RelationType = event.RelationType(relationType.String)
|
|
e.MegolmSessionID = id.SessionID(megolmSessionID.String)
|
|
e.DecryptedType = decryptedType.String
|
|
e.DecryptionError = decryptionError.String
|
|
e.SendError = sendError.String
|
|
return e, nil
|
|
}
|
|
|
|
var relatesToPath = exgjson.Path("m.relates_to", "event_id")
|
|
var relationTypePath = exgjson.Path("m.relates_to", "rel_type")
|
|
|
|
func getRelatesToFromEvent(evt *event.Event) (id.EventID, event.RelationType) {
|
|
if evt.StateKey != nil {
|
|
return "", ""
|
|
}
|
|
return GetRelatesToFromBytes(evt.Content.VeryRaw)
|
|
}
|
|
|
|
func GetRelatesToFromBytes(content []byte) (id.EventID, event.RelationType) {
|
|
results := gjson.GetManyBytes(content, relatesToPath, relationTypePath)
|
|
if len(results) == 2 && results[0].Exists() && results[1].Exists() && results[0].Type == gjson.String && results[1].Type == gjson.String {
|
|
return id.EventID(results[0].Str), event.RelationType(results[1].Str)
|
|
}
|
|
return "", ""
|
|
}
|
|
|
|
func getMegolmSessionID(evt *event.Event) id.SessionID {
|
|
if evt.Type != event.EventEncrypted {
|
|
return ""
|
|
}
|
|
res := gjson.GetBytes(evt.Content.VeryRaw, "session_id")
|
|
if res.Exists() && res.Type == gjson.String {
|
|
return id.SessionID(res.Str)
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (e *Event) sqlVariables() []any {
|
|
var reactions any
|
|
if e.Reactions != nil {
|
|
reactions = e.Reactions
|
|
}
|
|
return []any{
|
|
e.RoomID,
|
|
e.ID,
|
|
e.Sender,
|
|
e.Type,
|
|
e.StateKey,
|
|
e.Timestamp.UnixMilli(),
|
|
unsafeJSONString(e.Content),
|
|
unsafeJSONString(e.Decrypted),
|
|
dbutil.StrPtr(e.DecryptedType),
|
|
unsafeJSONString(e.Unsigned),
|
|
dbutil.JSONPtr(e.LocalContent),
|
|
dbutil.StrPtr(e.TransactionID),
|
|
dbutil.StrPtr(e.RedactedBy),
|
|
dbutil.StrPtr(e.RelatesTo),
|
|
dbutil.StrPtr(e.RelationType),
|
|
dbutil.StrPtr(e.MegolmSessionID),
|
|
dbutil.StrPtr(e.DecryptionError),
|
|
dbutil.StrPtr(e.SendError),
|
|
dbutil.JSON{Data: reactions},
|
|
e.LastEditRowID,
|
|
e.UnreadType,
|
|
}
|
|
}
|
|
|
|
func (e *Event) GetNonPushUnreadType() UnreadType {
|
|
if e.RelationType == event.RelReplace || e.RedactedBy != "" {
|
|
return UnreadTypeNone
|
|
}
|
|
switch e.Type {
|
|
case event.EventMessage.Type, event.EventSticker.Type, event.EventUnstablePollStart.Type:
|
|
return UnreadTypeNormal
|
|
case event.EventEncrypted.Type:
|
|
switch e.DecryptedType {
|
|
case event.EventMessage.Type, event.EventSticker.Type, event.EventUnstablePollStart.Type:
|
|
return UnreadTypeNormal
|
|
}
|
|
}
|
|
return UnreadTypeNone
|
|
}
|
|
|
|
func (e *Event) CanUseForPreview() bool {
|
|
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type ||
|
|
(e.Type == event.EventEncrypted.Type &&
|
|
(e.DecryptedType == event.EventMessage.Type || e.DecryptedType == event.EventSticker.Type))) &&
|
|
e.RelationType != event.RelReplace && e.RedactedBy == ""
|
|
}
|
|
|
|
func (e *Event) BumpsSortingTimestamp() bool {
|
|
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) &&
|
|
e.RelationType != event.RelReplace
|
|
}
|