all: move hicli from mautrix-go and add more features

This commit is contained in:
Tulir Asokan 2024-10-17 20:22:53 +03:00
parent d79be2b8cf
commit 1db1d2db5c
53 changed files with 6068 additions and 45 deletions

16
go.mod
View file

@ -5,33 +5,35 @@ go 1.23.0
toolchain go1.23.2
require (
github.com/chzyer/readline v1.5.1
github.com/coder/websocket v1.8.12
github.com/lucasb-eyer/go-colorful v1.2.0
github.com/mattn/go-sqlite3 v1.14.24
github.com/rivo/uniseg v0.4.7
github.com/rs/zerolog v1.33.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/yuin/goldmark v1.7.7
go.mau.fi/util v0.8.1
go.mau.fi/zeroconfig v0.1.3
golang.org/x/crypto v0.28.0
golang.org/x/net v0.30.0
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mauflag v1.0.0
maunium.net/go/mautrix v0.21.2-0.20241016143340-1d4c2d255455
maunium.net/go/mautrix v0.21.2-0.20241017173032-367828429297
mvdan.cc/xurls/v2 v2.5.0
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.6.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/yuin/goldmark v1.7.7 // indirect
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/sys v0.26.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)

13
go.sum
View file

@ -2,6 +2,12 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI=
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
@ -53,6 +59,7 @@ golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBn
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -66,5 +73,7 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/mautrix v0.21.2-0.20241016143340-1d4c2d255455 h1:6UznUe9nDojckcvBNq0h1vZFM/KmA7Xs24YowG8iE+4=
maunium.net/go/mautrix v0.21.2-0.20241016143340-1d4c2d255455/go.mod h1:7F/S6XAdyc/6DW+Q7xyFXRSPb6IjfqMb1OMepQ8C8OE=
maunium.net/go/mautrix v0.21.2-0.20241017173032-367828429297 h1:8CybV+x9HPh4p41nIqJMKHI0bUF0g0bEozq49ytNVlc=
maunium.net/go/mautrix v0.21.2-0.20241017173032-367828429297/go.mod h1:sjCZR1R/3NET/WjkcXPL6WpAHlWKku9HjRsdOkbM8Qw=
mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8=
mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE=

View file

@ -34,7 +34,8 @@ import (
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exzerolog"
"maunium.net/go/mautrix/hicli"
"go.mau.fi/gomuks/pkg/hicli"
)
type Gomuks struct {
@ -144,6 +145,7 @@ func (gmx *Gomuks) SetupLog() {
}
func (gmx *Gomuks) StartClient() {
hicli.HTMLSanitizerImgSrcTemplate = "_gomuks/media/%s/%s"
rawDB, err := dbutil.NewFromConfig("gomuks", dbutil.Config{
PoolConfig: dbutil.PoolConfig{
Type: "sqlite3-fk-wal",

View file

@ -27,7 +27,8 @@ import (
_ "go.mau.fi/util/dbutil/litestream"
flag "maunium.net/go/mauflag"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/hicli"
"go.mau.fi/gomuks/pkg/hicli"
)
var (

View file

@ -16,10 +16,10 @@ import (
"github.com/rs/zerolog"
"go.mau.fi/util/jsontime"
"go.mau.fi/util/ptr"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/hicli/database"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
)
var ErrBadGateway = mautrix.RespError{

373
pkg/hicli/LICENSE Normal file
View file

@ -0,0 +1,373 @@
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.

64
pkg/hicli/cryptohelper.go Normal file
View file

@ -0,0 +1,64 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"fmt"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
type hiCryptoHelper HiClient
var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil)
func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*event.EncryptedEventContent, error) {
roomMeta, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to get room metadata: %w", err)
} else if roomMeta == nil {
return nil, fmt.Errorf("unknown room")
}
return (*HiClient)(h).Encrypt(ctx, roomMeta, evtType, content)
}
func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
return h.Crypto.DecryptMegolmEvent(ctx, evt)
}
func (h *hiCryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
return h.Crypto.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
}
func (h *hiCryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
err := h.Crypto.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
userID: {deviceID},
h.Account.UserID: {"*"},
})
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("room_id", roomID).
Stringer("session_id", sessionID).
Stringer("user_id", userID).
Msg("Failed to send room key request")
} else {
zerolog.Ctx(ctx).Debug().
Stringer("room_id", roomID).
Stringer("session_id", sessionID).
Stringer("user_id", userID).
Msg("Sent room key request")
}
}
func (h *hiCryptoHelper) Init(ctx context.Context) error {
return nil
}

View file

@ -0,0 +1,73 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"errors"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
const (
getAccountQuery = `SELECT user_id, device_id, access_token, homeserver_url, next_batch FROM account WHERE user_id = $1`
putNextBatchQuery = `UPDATE account SET next_batch = $1 WHERE user_id = $2`
upsertAccountQuery = `
INSERT INTO account (user_id, device_id, access_token, homeserver_url, next_batch)
VALUES ($1, $2, $3, $4, $5) ON CONFLICT (user_id)
DO UPDATE SET device_id = excluded.device_id,
access_token = excluded.access_token,
homeserver_url = excluded.homeserver_url,
next_batch = excluded.next_batch
`
)
type AccountQuery struct {
*dbutil.QueryHelper[*Account]
}
func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) {
var exists bool
if exists, err = aq.GetDB().TableExists(ctx, "account"); err != nil || !exists {
return
}
err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (aq *AccountQuery) Get(ctx context.Context, userID id.UserID) (*Account, error) {
return aq.QueryOne(ctx, getAccountQuery, userID)
}
func (aq *AccountQuery) PutNextBatch(ctx context.Context, userID id.UserID, nextBatch string) error {
return aq.Exec(ctx, putNextBatchQuery, nextBatch, userID)
}
func (aq *AccountQuery) Put(ctx context.Context, account *Account) error {
return aq.Exec(ctx, upsertAccountQuery, account.sqlVariables()...)
}
type Account struct {
UserID id.UserID
DeviceID id.DeviceID
AccessToken string
HomeserverURL string
NextBatch string
}
func (a *Account) Scan(row dbutil.Scannable) (*Account, error) {
return dbutil.ValueOrErr(a, row.Scan(&a.UserID, &a.DeviceID, &a.AccessToken, &a.HomeserverURL, &a.NextBatch))
}
func (a *Account) sqlVariables() []any {
return []any{a.UserID, a.DeviceID, a.AccessToken, a.HomeserverURL, a.NextBatch}
}

View file

@ -0,0 +1,67 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"encoding/json"
"unsafe"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
upsertAccountDataQuery = `
INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3)
ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content
`
upsertRoomAccountDataQuery = `
INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content
`
)
type AccountDataQuery struct {
*dbutil.QueryHelper[*AccountData]
}
func unsafeJSONString(content json.RawMessage) *string {
if content == nil {
return nil
}
str := unsafe.String(unsafe.SliceData(content), len(content))
return &str
}
func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error {
return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content))
}
func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error {
return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content))
}
type AccountData struct {
UserID id.UserID `json:"user_id"`
RoomID id.RoomID `json:"room_id,omitempty"`
Type string `json:"type"`
Content json.RawMessage `json:"content"`
}
func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) {
var roomID sql.NullString
err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content))
if err != nil {
return nil, err
}
a.RoomID = id.RoomID(roomID.String)
return a, nil
}

View file

@ -0,0 +1,149 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"net/http"
"slices"
"time"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/id"
)
const (
insertCachedMediaQuery = `
INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (mxc) DO NOTHING
`
upsertCachedMediaQuery = `
INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (mxc) DO UPDATE
SET enc_file = excluded.enc_file,
file_name = excluded.file_name,
mime_type = excluded.mime_type,
size = excluded.size,
hash = excluded.hash,
error = excluded.error
WHERE excluded.error IS NULL OR cached_media.hash IS NULL
`
getCachedMediaQuery = `
SELECT mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error
FROM cached_media
WHERE mxc = $1
`
)
type CachedMediaQuery struct {
*dbutil.QueryHelper[*CachedMedia]
}
func (cmq *CachedMediaQuery) Add(ctx context.Context, cm *CachedMedia) error {
return cmq.Exec(ctx, insertCachedMediaQuery, cm.sqlVariables()...)
}
func (cmq *CachedMediaQuery) Put(ctx context.Context, cm *CachedMedia) error {
return cmq.Exec(ctx, upsertCachedMediaQuery, cm.sqlVariables()...)
}
func (cmq *CachedMediaQuery) Get(ctx context.Context, mxc id.ContentURI) (*CachedMedia, error) {
return cmq.QueryOne(ctx, getCachedMediaQuery, &mxc)
}
type MediaError struct {
Matrix *mautrix.RespError `json:"data"`
StatusCode int `json:"status_code"`
ReceivedAt jsontime.UnixMilli `json:"received_at"`
Attempts int `json:"attempts"`
}
const MaxMediaBackoff = 7 * 24 * time.Hour
func (me *MediaError) backoff() time.Duration {
return min(time.Duration(2<<me.Attempts)*time.Second, MaxMediaBackoff)
}
func (me *MediaError) UseCache() bool {
return me != nil && time.Since(me.ReceivedAt.Time) < me.backoff()
}
func (me *MediaError) Write(w http.ResponseWriter) {
if me.Matrix.ExtraData == nil {
me.Matrix.ExtraData = make(map[string]any)
}
me.Matrix.ExtraData["fi.mau.hicli.error_ts"] = me.ReceivedAt.UnixMilli()
me.Matrix.ExtraData["fi.mau.hicli.next_retry_ts"] = me.ReceivedAt.Add(me.backoff()).UnixMilli()
me.Matrix.WithStatus(me.StatusCode).Write(w)
}
type CachedMedia struct {
MXC id.ContentURI
EventRowID EventRowID
EncFile *attachment.EncryptedFile
FileName string
MimeType string
Size int64
Hash *[32]byte
Error *MediaError
}
func (c *CachedMedia) UseCache() bool {
return c != nil && (c.Hash != nil || c.Error.UseCache())
}
func (c *CachedMedia) sqlVariables() []any {
var hash []byte
if c.Hash != nil {
hash = c.Hash[:]
}
return []any{
&c.MXC, dbutil.NumPtr(c.EventRowID), dbutil.JSONPtr(c.EncFile),
dbutil.StrPtr(c.FileName), dbutil.StrPtr(c.MimeType), dbutil.NumPtr(c.Size),
hash, dbutil.JSONPtr(c.Error),
}
}
var safeMimes = []string{
"text/css", "text/plain", "text/csv",
"application/json", "application/ld+json",
"image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif",
"video/mp4", "video/webm", "video/ogg", "video/quicktime",
"audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave",
"audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac",
}
func (c *CachedMedia) Scan(row dbutil.Scannable) (*CachedMedia, error) {
var mimeType, fileName sql.NullString
var size, eventRowID sql.NullInt64
var hash []byte
err := row.Scan(&c.MXC, &eventRowID, dbutil.JSON{Data: &c.EncFile}, &fileName, &mimeType, &size, &hash, dbutil.JSON{Data: &c.Error})
if err != nil {
return nil, err
}
c.MimeType = mimeType.String
c.FileName = fileName.String
c.EventRowID = EventRowID(eventRowID.Int64)
c.Size = size.Int64
if len(hash) == 32 {
c.Hash = (*[32]byte)(hash)
}
return c, nil
}
func (c *CachedMedia) ContentDisposition() string {
if slices.Contains(safeMimes, c.MimeType) {
return "inline"
}
return "attachment"
}

View file

@ -0,0 +1,73 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"go.mau.fi/util/dbutil"
"go.mau.fi/gomuks/pkg/hicli/database/upgrades"
)
type Database struct {
*dbutil.Database
Account AccountQuery
AccountData AccountDataQuery
Room RoomQuery
Event EventQuery
CurrentState CurrentStateQuery
Timeline TimelineQuery
SessionRequest SessionRequestQuery
Receipt ReceiptQuery
CachedMedia CachedMediaQuery
}
func New(rawDB *dbutil.Database) *Database {
rawDB.UpgradeTable = upgrades.Table
eventQH := dbutil.MakeQueryHelper(rawDB, newEvent)
return &Database{
Database: rawDB,
Account: AccountQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccount)},
AccountData: AccountDataQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccountData)},
Room: RoomQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newRoom)},
Event: EventQuery{QueryHelper: eventQH},
CurrentState: CurrentStateQuery{QueryHelper: eventQH},
Timeline: TimelineQuery{QueryHelper: eventQH},
SessionRequest: SessionRequestQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newSessionRequest)},
Receipt: ReceiptQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newReceipt)},
CachedMedia: CachedMediaQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newCachedMedia)},
}
}
func newSessionRequest(_ *dbutil.QueryHelper[*SessionRequest]) *SessionRequest {
return &SessionRequest{}
}
func newEvent(_ *dbutil.QueryHelper[*Event]) *Event {
return &Event{}
}
func newRoom(_ *dbutil.QueryHelper[*Room]) *Room {
return &Room{}
}
func newReceipt(_ *dbutil.QueryHelper[*Receipt]) *Receipt {
return &Receipt{}
}
func newCachedMedia(_ *dbutil.QueryHelper[*CachedMedia]) *CachedMedia {
return &CachedMedia{}
}
func newAccountData(_ *dbutil.QueryHelper[*AccountData]) *AccountData {
return &AccountData{}
}
func newAccount(_ *dbutil.QueryHelper[*Account]) *Account {
return &Account{}
}

509
pkg/hicli/database/event.go Normal file
View file

@ -0,0 +1,509 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/tidwall/gjson"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exgjson"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
getEventBaseQuery = `
SELECT rowid, -1,
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
FROM event
`
getEventByRowID = getEventBaseQuery + `WHERE rowid = $1`
getManyEventsByRowID = getEventBaseQuery + `WHERE rowid IN (%s)`
getEventByID = getEventBaseQuery + `WHERE event_id = $1`
getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL`
insertEventBaseQuery = `
INSERT INTO event (
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
`
insertEventQuery = insertEventBaseQuery + `RETURNING rowid`
upsertEventQuery = insertEventBaseQuery + `
ON CONFLICT (event_id) DO UPDATE
SET decrypted=COALESCE(event.decrypted, excluded.decrypted),
decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type),
redacted_by=COALESCE(event.redacted_by, excluded.redacted_by),
decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END,
send_error=excluded.send_error,
timestamp=excluded.timestamp,
unsigned=COALESCE(excluded.unsigned, event.unsigned),
local_content=COALESCE(excluded.local_content, event.local_content)
ON CONFLICT (transaction_id) DO UPDATE
SET event_id=excluded.event_id,
timestamp=excluded.timestamp,
unsigned=excluded.unsigned
RETURNING rowid
`
updateEventSendErrorQuery = `UPDATE event SET send_error = $2 WHERE rowid = $1`
updateEventIDQuery = `UPDATE event SET event_id = $2, send_error = NULL WHERE rowid=$1`
updateEventDecryptedQuery = `UPDATE event SET decrypted = $2, decrypted_type = $3, decryption_error = NULL, unread_type = $4, local_content = $5 WHERE rowid = $1`
updateEventLocalContentQuery = `UPDATE event SET local_content = $2 WHERE rowid = $1`
getEventReactionsQuery = getEventBaseQuery + `
WHERE room_id = ?
AND type = 'm.reaction'
AND relation_type = 'm.annotation'
AND redacted_by IS NULL
AND relates_to IN (%s)
`
getEventEditRowIDsQuery = `
SELECT main.event_id, edit.rowid
FROM event main
JOIN event edit ON
edit.room_id = main.room_id
AND edit.relates_to = main.event_id
AND edit.relation_type = 'm.replace'
AND edit.type = main.type
AND edit.sender = main.sender
AND edit.redacted_by IS NULL
WHERE main.event_id IN (%s)
ORDER BY main.event_id, edit.timestamp
`
setLastEditRowIDQuery = `
UPDATE event SET last_edit_rowid = $2 WHERE event_id = $1
`
updateReactionCountsQuery = `UPDATE event SET reactions = $2 WHERE event_id = $1`
)
type EventQuery struct {
*dbutil.QueryHelper[*Event]
}
func (eq *EventQuery) GetFailedByMegolmSessionID(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) ([]*Event, error) {
return eq.QueryMany(ctx, getFailedEventsByMegolmSessionID, roomID, sessionID)
}
func (eq *EventQuery) GetByID(ctx context.Context, eventID id.EventID) (*Event, error) {
return eq.QueryOne(ctx, getEventByID, eventID)
}
func (eq *EventQuery) GetByRowID(ctx context.Context, rowID EventRowID) (*Event, error) {
return eq.QueryOne(ctx, getEventByRowID, rowID)
}
func (eq *EventQuery) GetByRowIDs(ctx context.Context, rowIDs ...EventRowID) ([]*Event, error) {
query, params := buildMultiEventGetFunction(nil, rowIDs, getManyEventsByRowID)
return eq.QueryMany(ctx, query, params...)
}
func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID EventRowID, err error) {
err = eq.GetDB().QueryRow(ctx, upsertEventQuery, evt.sqlVariables()...).Scan(&rowID)
if err == nil {
evt.RowID = rowID
}
return
}
func (eq *EventQuery) Insert(ctx context.Context, evt *Event) (rowID EventRowID, err error) {
err = eq.GetDB().QueryRow(ctx, insertEventQuery, evt.sqlVariables()...).Scan(&rowID)
if err == nil {
evt.RowID = rowID
}
return
}
func (eq *EventQuery) UpdateID(ctx context.Context, rowID EventRowID, newID id.EventID) error {
return eq.Exec(ctx, updateEventIDQuery, rowID, newID)
}
func (eq *EventQuery) UpdateSendError(ctx context.Context, rowID EventRowID, sendError string) error {
return eq.Exec(ctx, updateEventSendErrorQuery, rowID, sendError)
}
func (eq *EventQuery) UpdateDecrypted(ctx context.Context, evt *Event) error {
return eq.Exec(
ctx,
updateEventDecryptedQuery,
evt.RowID,
unsafeJSONString(evt.Decrypted),
evt.DecryptedType,
evt.UnreadType,
dbutil.JSONPtr(evt.LocalContent),
)
}
func (eq *EventQuery) UpdateLocalContent(ctx context.Context, evt *Event) error {
return eq.Exec(ctx, updateEventLocalContentQuery, evt.RowID, dbutil.JSONPtr(evt.LocalContent))
}
func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error {
eventIDs := make([]id.EventID, 0)
eventMap := make(map[id.EventID]*Event)
for i, evt := range events {
if evt.Reactions == nil {
eventIDs[i] = evt.ID
eventMap[evt.ID] = evt
}
}
result, err := eq.GetReactions(ctx, roomID, eventIDs...)
if err != nil {
return err
}
for evtID, res := range result {
eventMap[evtID].Reactions = res.Counts
}
return nil
}
func (eq *EventQuery) FillLastEditRowIDs(ctx context.Context, roomID id.RoomID, events []*Event) error {
eventIDs := make([]id.EventID, len(events))
eventMap := make(map[id.EventID]*Event)
for i, evt := range events {
if evt.LastEditRowID == nil {
eventIDs[i] = evt.ID
eventMap[evt.ID] = evt
}
}
return eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
result, err := eq.GetEditRowIDs(ctx, roomID, eventIDs...)
if err != nil {
return err
}
for evtID, res := range result {
lastEditRowID := res[len(res)-1]
eventMap[evtID].LastEditRowID = &lastEditRowID
delete(eventMap, evtID)
err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, lastEditRowID)
if err != nil {
return err
}
}
var zero EventRowID
for evtID, evt := range eventMap {
evt.LastEditRowID = &zero
err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, zero)
if err != nil {
return err
}
}
return nil
})
}
var reactionKeyPath = exgjson.Path("m.relates_to", "key")
type GetReactionsResult struct {
Events []*Event
Counts map[string]int
}
func buildMultiEventGetFunction[T any](preParams []any, eventIDs []T, query string) (string, []any) {
params := make([]any, len(preParams)+len(eventIDs))
copy(params, preParams)
for i, evtID := range eventIDs {
params[i+len(preParams)] = evtID
}
placeholders := strings.Repeat("?,", len(eventIDs))
placeholders = placeholders[:len(placeholders)-1]
return fmt.Sprintf(query, placeholders), params
}
type editRowIDTuple struct {
eventID id.EventID
editRowID EventRowID
}
func (eq *EventQuery) GetEditRowIDs(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID][]EventRowID, error) {
query, params := buildMultiEventGetFunction([]any{roomID}, eventIDs, getEventEditRowIDsQuery)
rows, err := eq.GetDB().Query(ctx, query, params...)
output := make(map[id.EventID][]EventRowID)
return output, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (tuple editRowIDTuple, err error) {
err = row.Scan(&tuple.eventID, &tuple.editRowID)
return
}, err).Iter(func(tuple editRowIDTuple) (bool, error) {
output[tuple.eventID] = append(output[tuple.eventID], tuple.editRowID)
return true, nil
})
}
func (eq *EventQuery) GetReactions(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID]*GetReactionsResult, error) {
result := make(map[id.EventID]*GetReactionsResult, len(eventIDs))
for _, evtID := range eventIDs {
result[evtID] = &GetReactionsResult{Counts: make(map[string]int)}
}
return result, eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
query, params := buildMultiEventGetFunction([]any{roomID}, eventIDs, getEventReactionsQuery)
events, err := eq.QueryMany(ctx, query, params...)
if err != nil {
return err
} else if len(events) == 0 {
return nil
}
for _, evt := range events {
dest := result[evt.RelatesTo]
dest.Events = append(dest.Events, evt)
keyRes := gjson.GetBytes(evt.Content, reactionKeyPath)
if keyRes.Type == gjson.String {
dest.Counts[keyRes.Str]++
}
}
for evtID, res := range result {
if len(res.Counts) > 0 {
err = eq.Exec(ctx, updateReactionCountsQuery, evtID, dbutil.JSON{Data: &res.Counts})
if err != nil {
return err
}
}
}
return nil
})
}
type EventRowID int64
func (m EventRowID) GetMassInsertValues() [1]any {
return [1]any{m}
}
type LocalContent struct {
SanitizedHTML string `json:"sanitized_html,omitempty"`
HTMLVersion int `json:"html_version,omitempty"`
}
type UnreadType int
func (ut UnreadType) Is(flag UnreadType) bool {
return ut&flag != 0
}
const (
UnreadTypeNone UnreadType = 0b0000
UnreadTypeNormal UnreadType = 0b0001
UnreadTypeNotify UnreadType = 0b0010
UnreadTypeHighlight UnreadType = 0b0100
UnreadTypeSound UnreadType = 0b1000
)
type Event struct {
RowID EventRowID `json:"rowid"`
TimelineRowID TimelineRowID `json:"timeline_rowid"`
RoomID id.RoomID `json:"room_id"`
ID id.EventID `json:"event_id"`
Sender id.UserID `json:"sender"`
Type string `json:"type"`
StateKey *string `json:"state_key,omitempty"`
Timestamp jsontime.UnixMilli `json:"timestamp"`
Content json.RawMessage `json:"content"`
Decrypted json.RawMessage `json:"decrypted,omitempty"`
DecryptedType string `json:"decrypted_type,omitempty"`
Unsigned json.RawMessage `json:"unsigned,omitempty"`
LocalContent *LocalContent `json:"local_content,omitempty"`
TransactionID string `json:"transaction_id,omitempty"`
RedactedBy id.EventID `json:"redacted_by,omitempty"`
RelatesTo id.EventID `json:"relates_to,omitempty"`
RelationType event.RelationType `json:"relation_type,omitempty"`
MegolmSessionID id.SessionID `json:"-,omitempty"`
DecryptionError string `json:"decryption_error,omitempty"`
SendError string `json:"send_error,omitempty"`
Reactions map[string]int `json:"reactions,omitempty"`
LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"`
UnreadType UnreadType `json:"unread_type,omitempty"`
}
func MautrixToEvent(evt *event.Event) *Event {
dbEvt := &Event{
RoomID: evt.RoomID,
ID: evt.ID,
Sender: evt.Sender,
Type: evt.Type.Type,
StateKey: evt.StateKey,
Timestamp: jsontime.UM(time.UnixMilli(evt.Timestamp)),
Content: evt.Content.VeryRaw,
MegolmSessionID: getMegolmSessionID(evt),
TransactionID: evt.Unsigned.TransactionID,
}
if !strings.HasPrefix(dbEvt.TransactionID, "hicli-mautrix-go_") {
dbEvt.TransactionID = ""
}
dbEvt.RelatesTo, dbEvt.RelationType = getRelatesToFromEvent(evt)
dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned)
if evt.Unsigned.RedactedBecause != nil {
dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID
}
return dbEvt
}
func (e *Event) AsRawMautrix() *event.Event {
if e == nil {
return nil
}
evt := &event.Event{
RoomID: e.RoomID,
ID: e.ID,
Sender: e.Sender,
Type: event.Type{Type: e.Type, Class: event.MessageEventType},
StateKey: e.StateKey,
Timestamp: e.Timestamp.UnixMilli(),
Content: event.Content{VeryRaw: e.Content},
}
if e.Decrypted != nil {
evt.Content.VeryRaw = e.Decrypted
evt.Type.Type = e.DecryptedType
evt.Mautrix.WasEncrypted = true
}
if e.StateKey != nil {
evt.Type.Class = event.StateEventType
}
_ = json.Unmarshal(e.Unsigned, &evt.Unsigned)
return evt
}
func (e *Event) Scan(row dbutil.Scannable) (*Event, error) {
var timestamp int64
var transactionID, redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, sendError, decryptedType sql.NullString
err := row.Scan(
&e.RowID,
&e.TimelineRowID,
&e.RoomID,
&e.ID,
&e.Sender,
&e.Type,
&e.StateKey,
&timestamp,
(*[]byte)(&e.Content),
(*[]byte)(&e.Decrypted),
&decryptedType,
(*[]byte)(&e.Unsigned),
dbutil.JSON{Data: &e.LocalContent},
&transactionID,
&redactedBy,
&relatesTo,
&relationType,
&megolmSessionID,
&decryptionError,
&sendError,
dbutil.JSON{Data: &e.Reactions},
&e.LastEditRowID,
&e.UnreadType,
)
if err != nil {
return nil, err
}
e.Timestamp = jsontime.UM(time.UnixMilli(timestamp))
e.TransactionID = transactionID.String
e.RedactedBy = id.EventID(redactedBy.String)
e.RelatesTo = id.EventID(relatesTo.String)
e.RelationType = event.RelationType(relationType.String)
e.MegolmSessionID = id.SessionID(megolmSessionID.String)
e.DecryptedType = decryptedType.String
e.DecryptionError = decryptionError.String
e.SendError = sendError.String
return e, nil
}
var relatesToPath = exgjson.Path("m.relates_to", "event_id")
var relationTypePath = exgjson.Path("m.relates_to", "rel_type")
func getRelatesToFromEvent(evt *event.Event) (id.EventID, event.RelationType) {
if evt.StateKey != nil {
return "", ""
}
return GetRelatesToFromBytes(evt.Content.VeryRaw)
}
func GetRelatesToFromBytes(content []byte) (id.EventID, event.RelationType) {
results := gjson.GetManyBytes(content, relatesToPath, relationTypePath)
if len(results) == 2 && results[0].Exists() && results[1].Exists() && results[0].Type == gjson.String && results[1].Type == gjson.String {
return id.EventID(results[0].Str), event.RelationType(results[1].Str)
}
return "", ""
}
func getMegolmSessionID(evt *event.Event) id.SessionID {
if evt.Type != event.EventEncrypted {
return ""
}
res := gjson.GetBytes(evt.Content.VeryRaw, "session_id")
if res.Exists() && res.Type == gjson.String {
return id.SessionID(res.Str)
}
return ""
}
func (e *Event) sqlVariables() []any {
var reactions any
if e.Reactions != nil {
reactions = e.Reactions
}
return []any{
e.RoomID,
e.ID,
e.Sender,
e.Type,
e.StateKey,
e.Timestamp.UnixMilli(),
unsafeJSONString(e.Content),
unsafeJSONString(e.Decrypted),
dbutil.StrPtr(e.DecryptedType),
unsafeJSONString(e.Unsigned),
dbutil.JSONPtr(e.LocalContent),
dbutil.StrPtr(e.TransactionID),
dbutil.StrPtr(e.RedactedBy),
dbutil.StrPtr(e.RelatesTo),
dbutil.StrPtr(e.RelationType),
dbutil.StrPtr(e.MegolmSessionID),
dbutil.StrPtr(e.DecryptionError),
dbutil.StrPtr(e.SendError),
dbutil.JSON{Data: reactions},
e.LastEditRowID,
e.UnreadType,
}
}
func (e *Event) GetNonPushUnreadType() UnreadType {
if e.RelationType == event.RelReplace {
return UnreadTypeNone
}
switch e.Type {
case event.EventMessage.Type, event.EventSticker.Type:
return UnreadTypeNormal
case event.EventEncrypted.Type:
switch e.DecryptedType {
case event.EventMessage.Type, event.EventSticker.Type:
return UnreadTypeNormal
}
}
return UnreadTypeNone
}
func (e *Event) CanUseForPreview() bool {
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type ||
(e.Type == event.EventEncrypted.Type &&
(e.DecryptedType == event.EventMessage.Type || e.DecryptedType == event.EventSticker.Type))) &&
e.RelationType != event.RelReplace && e.RedactedBy == ""
}
func (e *Event) BumpsSortingTimestamp() bool {
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) &&
e.RelationType != event.RelReplace
}

View file

@ -0,0 +1,81 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"slices"
"time"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
upsertReceiptQuery = `
INSERT INTO receipt (room_id, user_id, receipt_type, thread_id, event_id, timestamp)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (room_id, user_id, receipt_type, thread_id) DO UPDATE
SET event_id = excluded.event_id,
timestamp = excluded.timestamp
`
)
var receiptMassInserter = dbutil.NewMassInsertBuilder[*Receipt, [1]any](upsertReceiptQuery, "($1, $%d, $%d, $%d, $%d, $%d)")
type ReceiptQuery struct {
*dbutil.QueryHelper[*Receipt]
}
func (rq *ReceiptQuery) Put(ctx context.Context, receipt *Receipt) error {
return rq.Exec(ctx, upsertReceiptQuery, receipt.sqlVariables()...)
}
func (rq *ReceiptQuery) PutMany(ctx context.Context, roomID id.RoomID, receipts ...*Receipt) error {
if len(receipts) > 1000 {
return rq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
for receiptChunk := range slices.Chunk(receipts, 200) {
err := rq.PutMany(ctx, roomID, receiptChunk...)
if err != nil {
return err
}
}
return nil
})
}
query, params := receiptMassInserter.Build([1]any{roomID}, receipts)
return rq.Exec(ctx, query, params...)
}
type Receipt struct {
RoomID id.RoomID `json:"room_id"`
UserID id.UserID `json:"user_id"`
ReceiptType event.ReceiptType `json:"receipt_type"`
ThreadID event.ThreadID `json:"thread_id"`
EventID id.EventID `json:"event_id"`
Timestamp jsontime.UnixMilli `json:"timestamp"`
}
func (r *Receipt) Scan(row dbutil.Scannable) (*Receipt, error) {
var ts int64
err := row.Scan(&r.RoomID, &r.UserID, &r.ReceiptType, &r.ThreadID, &r.EventID, &ts)
if err != nil {
return nil, err
}
r.Timestamp = jsontime.UM(time.UnixMilli(ts))
return r, nil
}
func (r *Receipt) sqlVariables() []any {
return []any{r.RoomID, r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()}
}
func (r *Receipt) GetMassInsertValues() [5]any {
return [5]any{r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()}
}

278
pkg/hicli/database/room.go Normal file
View file

@ -0,0 +1,278 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"errors"
"time"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
getRoomBaseQuery = `
SELECT room_id, creation_content, name, name_quality, avatar, explicit_avatar, topic, canonical_alias,
lazy_load_summary, encryption_event, has_member_list, preview_event_rowid, sorting_timestamp,
unread_highlights, unread_notifications, unread_messages, prev_batch
FROM room
`
getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 AND sorting_timestamp > 0 ORDER BY sorting_timestamp DESC LIMIT $2`
getRoomByIDQuery = getRoomBaseQuery + `WHERE room_id = $1`
ensureRoomExistsQuery = `
INSERT INTO room (room_id) VALUES ($1)
ON CONFLICT (room_id) DO NOTHING
`
upsertRoomFromSyncQuery = `
UPDATE room
SET creation_content = COALESCE(room.creation_content, $2),
name = COALESCE($3, room.name),
name_quality = CASE WHEN $3 IS NOT NULL THEN $4 ELSE room.name_quality END,
avatar = COALESCE($5, room.avatar),
explicit_avatar = CASE WHEN $5 IS NOT NULL THEN $6 ELSE room.explicit_avatar END,
topic = COALESCE($7, room.topic),
canonical_alias = COALESCE($8, room.canonical_alias),
lazy_load_summary = COALESCE($9, room.lazy_load_summary),
encryption_event = COALESCE($10, room.encryption_event),
has_member_list = room.has_member_list OR $11,
preview_event_rowid = COALESCE($12, room.preview_event_rowid),
sorting_timestamp = COALESCE($13, room.sorting_timestamp),
unread_highlights = COALESCE($14, room.unread_highlights),
unread_notifications = COALESCE($15, room.unread_notifications),
unread_messages = COALESCE($16, room.unread_messages),
prev_batch = COALESCE($17, room.prev_batch)
WHERE room_id = $1
`
setRoomPrevBatchQuery = `
UPDATE room SET prev_batch = $2 WHERE room_id = $1
`
updateRoomPreviewIfLaterOnTimelineQuery = `
UPDATE room
SET preview_event_rowid = $2
WHERE room_id = $1
AND COALESCE((SELECT rowid FROM timeline WHERE event_rowid = $2), -1)
> COALESCE((SELECT rowid FROM timeline WHERE event_rowid = preview_event_rowid), 0)
RETURNING preview_event_rowid
`
recalculateRoomPreviewEventQuery = `
SELECT rowid
FROM event
WHERE
room_id = $1
AND (type IN ('m.room.message', 'm.sticker')
OR (type = 'm.room.encrypted'
AND decrypted_type IN ('m.room.message', 'm.sticker')))
AND relation_type <> 'm.replace'
AND redacted_by IS NULL
ORDER BY timestamp DESC
LIMIT 1
`
)
type RoomQuery struct {
*dbutil.QueryHelper[*Room]
}
func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) {
return rq.QueryOne(ctx, getRoomByIDQuery, roomID)
}
func (rq *RoomQuery) GetBySortTS(ctx context.Context, maxTS time.Time, limit int) ([]*Room, error) {
return rq.QueryMany(ctx, getRoomsBySortingTimestampQuery, maxTS.UnixMilli(), limit)
}
func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error {
return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...)
}
func (rq *RoomQuery) CreateRow(ctx context.Context, roomID id.RoomID) error {
return rq.Exec(ctx, ensureRoomExistsQuery, roomID)
}
func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBatch string) error {
return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch)
}
func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID id.RoomID, rowID EventRowID) (previewChanged bool, err error) {
var newPreviewRowID EventRowID
err = rq.GetDB().QueryRow(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID).Scan(&newPreviewRowID)
if errors.Is(err, sql.ErrNoRows) {
err = nil
} else if err == nil {
previewChanged = newPreviewRowID == rowID
}
return
}
func (rq *RoomQuery) RecalculatePreview(ctx context.Context, roomID id.RoomID) (rowID EventRowID, err error) {
err = rq.GetDB().QueryRow(ctx, recalculateRoomPreviewEventQuery, roomID).Scan(&rowID)
return
}
type NameQuality int
const (
NameQualityNil NameQuality = iota
NameQualityParticipants
NameQualityCanonicalAlias
NameQualityExplicit
)
const PrevBatchPaginationComplete = "fi.mau.gomuks.pagination_complete"
type Room struct {
ID id.RoomID `json:"room_id"`
CreationContent *event.CreateEventContent `json:"creation_content,omitempty"`
Name *string `json:"name,omitempty"`
NameQuality NameQuality `json:"name_quality"`
Avatar *id.ContentURI `json:"avatar,omitempty"`
ExplicitAvatar bool `json:"explicit_avatar"`
Topic *string `json:"topic,omitempty"`
CanonicalAlias *id.RoomAlias `json:"canonical_alias,omitempty"`
LazyLoadSummary *mautrix.LazyLoadSummary `json:"lazy_load_summary,omitempty"`
EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"`
HasMemberList bool `json:"has_member_list"`
PreviewEventRowID EventRowID `json:"preview_event_rowid"`
SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"`
UnreadHighlights int `json:"unread_highlights"`
UnreadNotifications int `json:"unread_notifications"`
UnreadMessages int `json:"unread_messages"`
PrevBatch string `json:"prev_batch"`
}
func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) {
if r.Name != nil && r.NameQuality >= other.NameQuality {
other.Name = r.Name
other.NameQuality = r.NameQuality
hasChanges = true
}
if r.Avatar != nil {
other.Avatar = r.Avatar
other.ExplicitAvatar = r.ExplicitAvatar
hasChanges = true
}
if r.Topic != nil {
other.Topic = r.Topic
hasChanges = true
}
if r.CanonicalAlias != nil {
other.CanonicalAlias = r.CanonicalAlias
hasChanges = true
}
if r.LazyLoadSummary != nil {
other.LazyLoadSummary = r.LazyLoadSummary
hasChanges = true
}
if r.EncryptionEvent != nil && other.EncryptionEvent == nil {
other.EncryptionEvent = r.EncryptionEvent
hasChanges = true
}
if r.HasMemberList && !other.HasMemberList {
hasChanges = true
other.HasMemberList = true
}
if r.PreviewEventRowID > other.PreviewEventRowID {
other.PreviewEventRowID = r.PreviewEventRowID
hasChanges = true
}
if r.SortingTimestamp.After(other.SortingTimestamp.Time) {
other.SortingTimestamp = r.SortingTimestamp
hasChanges = true
}
if r.UnreadHighlights != other.UnreadHighlights {
other.UnreadHighlights = r.UnreadHighlights
hasChanges = true
}
if r.UnreadNotifications != other.UnreadNotifications {
other.UnreadNotifications = r.UnreadNotifications
hasChanges = true
}
if r.UnreadMessages != other.UnreadMessages {
other.UnreadMessages = r.UnreadMessages
hasChanges = true
}
if r.PrevBatch != "" && other.PrevBatch == "" {
other.PrevBatch = r.PrevBatch
hasChanges = true
}
return
}
func (r *Room) Scan(row dbutil.Scannable) (*Room, error) {
var prevBatch sql.NullString
var previewEventRowID, sortingTimestamp sql.NullInt64
err := row.Scan(
&r.ID,
dbutil.JSON{Data: &r.CreationContent},
&r.Name,
&r.NameQuality,
&r.Avatar,
&r.ExplicitAvatar,
&r.Topic,
&r.CanonicalAlias,
dbutil.JSON{Data: &r.LazyLoadSummary},
dbutil.JSON{Data: &r.EncryptionEvent},
&r.HasMemberList,
&previewEventRowID,
&sortingTimestamp,
&r.UnreadHighlights,
&r.UnreadNotifications,
&r.UnreadMessages,
&prevBatch,
)
if err != nil {
return nil, err
}
r.PrevBatch = prevBatch.String
r.PreviewEventRowID = EventRowID(previewEventRowID.Int64)
r.SortingTimestamp = jsontime.UM(time.UnixMilli(sortingTimestamp.Int64))
return r, nil
}
func (r *Room) sqlVariables() []any {
return []any{
r.ID,
dbutil.JSONPtr(r.CreationContent),
r.Name,
r.NameQuality,
r.Avatar,
r.ExplicitAvatar,
r.Topic,
r.CanonicalAlias,
dbutil.JSONPtr(r.LazyLoadSummary),
dbutil.JSONPtr(r.EncryptionEvent),
r.HasMemberList,
dbutil.NumPtr(r.PreviewEventRowID),
dbutil.UnixMilliPtr(r.SortingTimestamp.Time),
r.UnreadHighlights,
r.UnreadNotifications,
r.UnreadMessages,
dbutil.StrPtr(r.PrevBatch),
}
}
func (r *Room) BumpSortingTimestamp(evt *Event) bool {
if !evt.BumpsSortingTimestamp() || evt.Timestamp.Before(r.SortingTimestamp.Time) {
return false
}
r.SortingTimestamp = evt.Timestamp
now := time.Now()
if r.SortingTimestamp.After(now) {
r.SortingTimestamp = jsontime.UM(now)
}
return true
}

View file

@ -0,0 +1,68 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
const (
putSessionRequestQueueEntry = `
INSERT INTO session_request (room_id, session_id, sender, min_index, backup_checked, request_sent)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (session_id) DO UPDATE
SET min_index = MIN(excluded.min_index, session_request.min_index),
backup_checked = excluded.backup_checked OR session_request.backup_checked,
request_sent = excluded.request_sent OR session_request.request_sent
`
removeSessionRequestQuery = `
DELETE FROM session_request WHERE session_id = $1 AND min_index >= $2
`
getNextSessionsToRequestQuery = `
SELECT room_id, session_id, sender, min_index, backup_checked, request_sent
FROM session_request
WHERE request_sent = false OR backup_checked = false
ORDER BY backup_checked, rowid
LIMIT $1
`
)
type SessionRequestQuery struct {
*dbutil.QueryHelper[*SessionRequest]
}
func (srq *SessionRequestQuery) Next(ctx context.Context, count int) ([]*SessionRequest, error) {
return srq.QueryMany(ctx, getNextSessionsToRequestQuery, count)
}
func (srq *SessionRequestQuery) Remove(ctx context.Context, sessionID id.SessionID, minIndex uint32) error {
return srq.Exec(ctx, removeSessionRequestQuery, sessionID, minIndex)
}
func (srq *SessionRequestQuery) Put(ctx context.Context, sr *SessionRequest) error {
return srq.Exec(ctx, putSessionRequestQueueEntry, sr.sqlVariables()...)
}
type SessionRequest struct {
RoomID id.RoomID
SessionID id.SessionID
Sender id.UserID
MinIndex uint32
BackupChecked bool
RequestSent bool
}
func (s *SessionRequest) Scan(row dbutil.Scannable) (*SessionRequest, error) {
return dbutil.ValueOrErr(s, row.Scan(&s.RoomID, &s.SessionID, &s.Sender, &s.MinIndex, &s.BackupChecked, &s.RequestSent))
}
func (s *SessionRequest) sqlVariables() []any {
return []any{s.RoomID, s.SessionID, s.Sender, s.MinIndex, s.BackupChecked, s.RequestSent}
}

View file

@ -0,0 +1,94 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"fmt"
"slices"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
setCurrentStateQuery = `
INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership
`
addCurrentStateQuery = `
INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
`
deleteCurrentStateQuery = `
DELETE FROM current_state WHERE room_id = $1
`
getCurrentRoomStateQuery = `
SELECT event.rowid, -1,
event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type,
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
FROM current_state cs
JOIN event ON cs.event_rowid = event.rowid
WHERE cs.room_id = $1
`
getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3`
)
var massInsertCurrentStateBuilder = dbutil.NewMassInsertBuilder[*CurrentStateEntry, [1]any](addCurrentStateQuery, "($1, $%d, $%d, $%d, $%d)")
const currentStateMassInsertBatchSize = 1000
type CurrentStateEntry struct {
EventType event.Type
StateKey string
EventRowID EventRowID
Membership event.Membership
}
func (cse *CurrentStateEntry) GetMassInsertValues() [4]any {
return [4]any{cse.EventType.Type, cse.StateKey, cse.EventRowID, dbutil.StrPtr(cse.Membership)}
}
type CurrentStateQuery struct {
*dbutil.QueryHelper[*Event]
}
func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error {
return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
}
func (csq *CurrentStateQuery) AddMany(ctx context.Context, roomID id.RoomID, deleteOld bool, entries []*CurrentStateEntry) error {
var err error
if deleteOld {
err = csq.Exec(ctx, deleteCurrentStateQuery, roomID)
if err != nil {
return fmt.Errorf("failed to delete old state: %w", err)
}
}
for entryChunk := range slices.Chunk(entries, currentStateMassInsertBatchSize) {
query, params := massInsertCurrentStateBuilder.Build([1]any{roomID}, entryChunk)
err = csq.Exec(ctx, query, params...)
if err != nil {
return err
}
}
return nil
}
func (csq *CurrentStateQuery) Add(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error {
return csq.Exec(ctx, addCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
}
func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) {
return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey)
}
func (csq *CurrentStateQuery) GetAll(ctx context.Context, roomID id.RoomID) ([]*Event, error) {
return csq.QueryMany(ctx, getCurrentRoomStateQuery, roomID)
}

View file

@ -0,0 +1,187 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"slices"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
const (
getMembershipQuery = `
SELECT membership FROM current_state
WHERE room_id = $1 AND event_type = 'm.room.member' AND state_key = $2
`
getStateEventContentQuery = `
SELECT event.content FROM current_state cs
LEFT JOIN event ON event.rowid = cs.event_rowid
WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3
`
getRoomJoinedMembersQuery = `
SELECT state_key FROM current_state
WHERE room_id = $1 AND event_type = 'm.room.member' AND membership = 'join'
`
getRoomJoinedOrInvitedMembersQuery = `
SELECT state_key FROM current_state
WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite')
`
getHasFetchedMembersQuery = `
SELECT has_member_list FROM room WHERE room_id = $1
`
isRoomEncryptedQuery = `
SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1
`
getRoomEncryptionEventQuery = `
SELECT room.encryption_event FROM room WHERE room_id = $1
`
findSharedRoomsQuery = `
SELECT room_id FROM current_state
WHERE event_type = 'm.room.member' AND state_key = $1 AND membership = 'join'
`
)
type ClientStateStore struct {
*Database
}
var _ mautrix.StateStore = (*ClientStateStore)(nil)
var _ mautrix.StateStoreUpdater = (*ClientStateStore)(nil)
var _ crypto.StateStore = (*ClientStateStore)(nil)
func (c *ClientStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return c.IsMembership(ctx, roomID, userID, event.MembershipJoin)
}
func (c *ClientStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return c.IsMembership(ctx, roomID, userID, event.MembershipInvite, event.MembershipJoin)
}
func (c *ClientStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
var membership event.Membership
err := c.QueryRow(ctx, getMembershipQuery, roomID, userID).Scan(&membership)
if errors.Is(err, sql.ErrNoRows) {
err = nil
membership = event.MembershipLeave
}
return slices.Contains(allowedMemberships, membership)
}
func (c *ClientStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
content, err := c.TryGetMember(ctx, roomID, userID)
if content == nil {
content = &event.MemberEventContent{Membership: event.MembershipLeave}
}
return content, err
}
func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (content *event.MemberEventContent, err error) {
err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StateMember.Type, userID).Scan(&dbutil.JSON{Data: &content})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (c *ClientStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) {
//TODO implement me
panic("implement me")
}
func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) {
err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (c *ClientStateStore) GetRoomJoinedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
rows, err := c.Query(ctx, getRoomJoinedMembersQuery, roomID)
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
}
func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID)
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
}
func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (hasFetched bool, err error) {
//err = c.QueryRow(ctx, getHasFetchedMembersQuery, roomID).Scan(&hasFetched)
//if errors.Is(err, sql.ErrNoRows) {
// err = nil
//}
//return
return false, fmt.Errorf("not implemented")
}
func (c *ClientStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error {
return fmt.Errorf("not implemented")
}
func (c *ClientStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
return nil, fmt.Errorf("not implemented")
}
func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) {
err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (c *ClientStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (content *event.EncryptionEventContent, err error) {
err = c.QueryRow(ctx, getRoomEncryptionEventQuery, roomID).
Scan(&dbutil.JSON{Data: &content})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (c *ClientStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) {
// TODO for multiuser support, this might need to filter by the local user's membership
rows, err := c.Query(ctx, findSharedRoomsQuery, userID)
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList()
}
// Update methods are all intentionally no-ops as the state store wants to have the full event
func (c *ClientStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
return nil
}
func (c *ClientStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
return nil
}
func (c *ClientStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
return nil
}
func (c *ClientStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
return nil
}
func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
return nil
}
func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {}
func (c *ClientStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
return nil
}

View file

@ -0,0 +1,135 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package database
import (
"context"
"database/sql"
"errors"
"sync"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/id"
)
const (
clearTimelineQuery = `
DELETE FROM timeline WHERE room_id = $1
`
appendTimelineQuery = `
INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2)
ON CONFLICT DO NOTHING
RETURNING rowid, event_rowid
`
prependTimelineQuery = `
INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3)
`
checkTimelineContainsQuery = `
SELECT EXISTS(SELECT 1 FROM timeline WHERE room_id = $1 AND event_rowid = $2)
`
findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline`
getTimelineQuery = `
SELECT event.rowid, timeline.rowid,
event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
FROM timeline
JOIN event ON event.rowid = timeline.event_rowid
WHERE timeline.room_id = $1 AND ($2 = 0 OR timeline.rowid < $2)
ORDER BY timeline.rowid DESC
LIMIT $3
`
)
type TimelineRowID int64
type TimelineRowTuple struct {
Timeline TimelineRowID `json:"timeline_rowid"`
Event EventRowID `json:"event_rowid"`
}
var timelineRowTupleScanner = dbutil.ConvertRowFn[TimelineRowTuple](func(row dbutil.Scannable) (trt TimelineRowTuple, err error) {
err = row.Scan(&trt.Timeline, &trt.Event)
return
})
func (trt TimelineRowTuple) GetMassInsertValues() [2]any {
return [2]any{trt.Timeline, trt.Event}
}
var appendTimelineQueryBuilder = dbutil.NewMassInsertBuilder[EventRowID, [1]any](appendTimelineQuery, "($1, $%d)")
var prependTimelineQueryBuilder = dbutil.NewMassInsertBuilder[TimelineRowTuple, [1]any](prependTimelineQuery, "($1, $%d, $%d)")
type TimelineQuery struct {
*dbutil.QueryHelper[*Event]
minRowID TimelineRowID
minRowIDFound bool
prependLock sync.Mutex
}
// Clear clears the timeline of a given room.
func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error {
return tq.Exec(ctx, clearTimelineQuery, roomID)
}
func (tq *TimelineQuery) reserveRowIDs(ctx context.Context, count int) (startFrom TimelineRowID, err error) {
tq.prependLock.Lock()
defer tq.prependLock.Unlock()
if !tq.minRowIDFound {
err = tq.GetDB().QueryRow(ctx, findMinRowIDQuery).Scan(&tq.minRowID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return
}
if tq.minRowID >= 0 {
// No negative row IDs exist, start at -2
tq.minRowID = -2
} else {
// We fetched the lowest row ID, but we want the next available one, so decrement one
tq.minRowID--
}
tq.minRowIDFound = true
}
startFrom = tq.minRowID
tq.minRowID -= TimelineRowID(count)
return
}
// Prepend adds the given event row IDs to the beginning of the timeline.
// The events must be sorted in reverse chronological order (newest event first).
func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) (prependEntries []TimelineRowTuple, err error) {
var startFrom TimelineRowID
startFrom, err = tq.reserveRowIDs(ctx, len(rowIDs))
if err != nil {
return
}
prependEntries = make([]TimelineRowTuple, len(rowIDs))
for i, rowID := range rowIDs {
prependEntries[i] = TimelineRowTuple{
Timeline: startFrom - TimelineRowID(i),
Event: rowID,
}
}
query, params := prependTimelineQueryBuilder.Build([1]any{roomID}, prependEntries)
err = tq.Exec(ctx, query, params...)
return
}
// Append adds the given event row IDs to the end of the timeline.
func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) ([]TimelineRowTuple, error) {
query, params := appendTimelineQueryBuilder.Build([1]any{roomID}, rowIDs)
return timelineRowTupleScanner.NewRowIter(tq.GetDB().Query(ctx, query, params...)).AsList()
}
func (tq *TimelineQuery) Get(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) {
return tq.QueryMany(ctx, getTimelineQuery, roomID, before, limit)
}
func (tq *TimelineQuery) Has(ctx context.Context, roomID id.RoomID, eventRowID EventRowID) (exists bool, err error) {
err = tq.GetDB().QueryRow(ctx, checkTimelineContainsQuery, roomID, eventRowID).Scan(&exists)
return
}

View file

@ -0,0 +1,255 @@
-- v0 -> v3 (compatible with v1+): Latest revision
CREATE TABLE account (
user_id TEXT NOT NULL PRIMARY KEY,
device_id TEXT NOT NULL,
access_token TEXT NOT NULL,
homeserver_url TEXT NOT NULL,
next_batch TEXT NOT NULL
) STRICT;
CREATE TABLE room (
room_id TEXT NOT NULL PRIMARY KEY,
creation_content TEXT,
name TEXT,
name_quality INTEGER NOT NULL DEFAULT 0,
avatar TEXT,
explicit_avatar INTEGER NOT NULL DEFAULT 0,
topic TEXT,
canonical_alias TEXT,
lazy_load_summary TEXT,
encryption_event TEXT,
has_member_list INTEGER NOT NULL DEFAULT false,
preview_event_rowid INTEGER,
sorting_timestamp INTEGER,
unread_highlights INTEGER NOT NULL DEFAULT 0,
unread_notifications INTEGER NOT NULL DEFAULT 0,
unread_messages INTEGER NOT NULL DEFAULT 0,
prev_batch TEXT,
CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL
) STRICT;
CREATE INDEX room_type_idx ON room (creation_content ->> 'type');
CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC);
-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_notifications > 0);
-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_messages > 0);
CREATE TABLE account_data (
user_id TEXT NOT NULL,
type TEXT NOT NULL,
content TEXT NOT NULL,
PRIMARY KEY (user_id, type)
) STRICT;
CREATE TABLE room_account_data (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
content TEXT NOT NULL,
PRIMARY KEY (user_id, room_id, type),
CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
) STRICT;
CREATE INDEX room_account_data_room_id_idx ON room_account_data (room_id);
CREATE TABLE event (
rowid INTEGER PRIMARY KEY,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
sender TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT,
timestamp INTEGER NOT NULL,
content TEXT NOT NULL,
decrypted TEXT,
decrypted_type TEXT,
unsigned TEXT NOT NULL,
local_content TEXT,
transaction_id TEXT,
redacted_by TEXT,
relates_to TEXT,
relation_type TEXT,
megolm_session_id TEXT,
decryption_error TEXT,
send_error TEXT,
reactions TEXT,
last_edit_rowid INTEGER,
unread_type INTEGER NOT NULL DEFAULT 0,
CONSTRAINT event_id_unique_key UNIQUE (event_id),
CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id),
CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
) STRICT;
CREATE INDEX event_room_id_idx ON event (room_id);
CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by);
CREATE INDEX event_relates_to_idx ON event (room_id, relates_to);
CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id);
CREATE TRIGGER event_update_redacted_by
AFTER INSERT
ON event
WHEN NEW.type = 'm.room.redaction'
BEGIN
UPDATE event SET redacted_by = NEW.event_id WHERE room_id = NEW.room_id AND event_id = NEW.content ->> 'redacts';
END;
CREATE TRIGGER event_update_last_edit_when_redacted
AFTER UPDATE
ON event
WHEN OLD.redacted_by IS NULL
AND NEW.redacted_by IS NOT NULL
AND NEW.relation_type = 'm.replace'
AND NEW.state_key IS NULL
BEGIN
UPDATE event
SET last_edit_rowid = COALESCE(
(SELECT rowid
FROM event edit
WHERE edit.room_id = event.room_id
AND edit.relates_to = event.event_id
AND edit.relation_type = 'm.replace'
AND edit.type = event.type
AND edit.sender = event.sender
AND edit.redacted_by IS NULL
AND edit.state_key IS NULL
ORDER BY edit.timestamp DESC
LIMIT 1),
0)
WHERE event_id = NEW.relates_to
AND last_edit_rowid = NEW.rowid
AND state_key IS NULL
AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation'));
END;
CREATE TRIGGER event_insert_update_last_edit
AFTER INSERT
ON event
WHEN NEW.relation_type = 'm.replace'
AND NEW.redacted_by IS NULL
AND NEW.state_key IS NULL
BEGIN
UPDATE event
SET last_edit_rowid = NEW.rowid
WHERE event_id = NEW.relates_to
AND type = NEW.type
AND sender = NEW.sender
AND state_key IS NULL
AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation'))
AND NEW.timestamp >
COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0);
END;
CREATE TRIGGER event_insert_fill_reactions
AFTER INSERT
ON event
WHEN NEW.type = 'm.reaction'
AND NEW.relation_type = 'm.annotation'
AND NEW.redacted_by IS NULL
AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text'
BEGIN
UPDATE event
SET reactions=json_set(
reactions,
'$.' || json_quote(NEW.content ->> '$."m.relates_to".key'),
coalesce(
reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')),
0
) + 1)
WHERE event_id = NEW.relates_to
AND reactions IS NOT NULL;
END;
CREATE TRIGGER event_redact_fill_reactions
AFTER UPDATE
ON event
WHEN NEW.type = 'm.reaction'
AND NEW.relation_type = 'm.annotation'
AND NEW.redacted_by IS NOT NULL
AND OLD.redacted_by IS NULL
AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text'
BEGIN
UPDATE event
SET reactions=json_set(
reactions,
'$.' || json_quote(NEW.content ->> '$."m.relates_to".key'),
coalesce(
reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')),
0
) - 1)
WHERE event_id = NEW.relates_to
AND reactions IS NOT NULL;
END;
CREATE TABLE cached_media (
mxc TEXT NOT NULL PRIMARY KEY,
event_rowid INTEGER,
enc_file TEXT,
file_name TEXT,
mime_type TEXT,
size INTEGER,
hash BLOB,
error TEXT,
CONSTRAINT cached_media_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE SET NULL
) STRICT;
CREATE TABLE session_request (
room_id TEXT NOT NULL,
session_id TEXT NOT NULL,
sender TEXT NOT NULL,
min_index INTEGER NOT NULL,
backup_checked INTEGER NOT NULL DEFAULT false,
request_sent INTEGER NOT NULL DEFAULT false,
PRIMARY KEY (session_id),
CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
) STRICT;
CREATE INDEX session_request_room_idx ON session_request (room_id);
CREATE TABLE timeline (
rowid INTEGER PRIMARY KEY,
room_id TEXT NOT NULL,
event_rowid INTEGER NOT NULL,
CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE,
CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE,
CONSTRAINT timeline_event_unique_key UNIQUE (event_rowid)
) STRICT;
CREATE INDEX timeline_room_id_idx ON timeline (room_id);
CREATE TABLE current_state (
room_id TEXT NOT NULL,
event_type TEXT NOT NULL,
state_key TEXT NOT NULL,
event_rowid INTEGER NOT NULL,
membership TEXT,
PRIMARY KEY (room_id, event_type, state_key),
CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE,
CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid)
) STRICT, WITHOUT ROWID;
CREATE TABLE receipt (
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
thread_id TEXT NOT NULL,
event_id TEXT NOT NULL,
timestamp INTEGER NOT NULL,
PRIMARY KEY (room_id, user_id, receipt_type, thread_id),
CONSTRAINT receipt_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
-- note: there's no foreign key on event ID because receipts could point at events that are too far in history.
) STRICT;

View file

@ -0,0 +1,2 @@
-- v2 (compatible with v1+): Add explicit avatar flag to rooms
ALTER TABLE room ADD COLUMN explicit_avatar INTEGER NOT NULL DEFAULT 0;

View file

@ -0,0 +1,6 @@
-- v3 (compatible with v1+): Add more fields to events
ALTER TABLE event ADD COLUMN local_content TEXT;
ALTER TABLE event ADD COLUMN unread_type INTEGER NOT NULL DEFAULT 0;
ALTER TABLE room ADD COLUMN unread_highlights INTEGER NOT NULL DEFAULT 0;
ALTER TABLE room ADD COLUMN unread_notifications INTEGER NOT NULL DEFAULT 0;
ALTER TABLE room ADD COLUMN unread_messages INTEGER NOT NULL DEFAULT 0;

View file

@ -0,0 +1,22 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package upgrades
import (
"embed"
"go.mau.fi/util/dbutil"
)
var Table dbutil.UpgradeTable
//go:embed *.sql
var upgrades embed.FS
func init() {
Table.RegisterFS(upgrades)
}

View file

@ -0,0 +1,209 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"fmt"
"sync"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
)
func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID)
if err != nil {
return nil, err
} else if data == nil {
return nil, nil
}
decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey)
if err != nil {
return nil, err
}
return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted)
}
func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) {
log := zerolog.Ctx(ctx)
err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex)
if err != nil {
log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session")
}
events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID)
if err != nil {
log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption")
return
} else if len(events) == 0 {
log.Trace().Msg("No events to retry decryption for")
return
}
decrypted := events[:0]
for _, evt := range events {
if evt.Decrypted != nil {
continue
}
var mautrixEvt *event.Event
mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
if err != nil {
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
} else {
decrypted = append(decrypted, evt)
h.postDecryptProcess(ctx, nil, evt, mautrixEvt)
}
}
if len(decrypted) > 0 {
var newPreview database.EventRowID
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
for _, evt := range decrypted {
err = h.DB.Event.UpdateDecrypted(ctx, evt)
if err != nil {
return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err)
}
if evt.CanUseForPreview() {
var previewChanged bool
previewChanged, err = h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID)
if err != nil {
return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, err)
} else if previewChanged {
newPreview = evt.RowID
}
}
}
return nil
})
if err != nil {
log.Err(err).Msg("Failed to save decrypted events")
} else {
h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewEventRowID: newPreview, RoomID: roomID})
}
}
}
func (h *HiClient) WakeupRequestQueue() {
select {
case h.requestQueueWakeup <- struct{}{}:
default:
}
}
func (h *HiClient) RunRequestQueue(ctx context.Context) {
log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger()
ctx = log.WithContext(ctx)
log.Info().Msg("Starting key request queue")
defer func() {
log.Info().Msg("Stopping key request queue")
}()
for {
err := h.FetchKeysForOutdatedUsers(ctx)
if err != nil {
log.Err(err).Msg("Failed to fetch outdated device lists for tracked users")
}
madeRequests, err := h.RequestQueuedSessions(ctx)
if err != nil {
log.Err(err).Msg("Failed to handle session request queue")
} else if madeRequests {
continue
}
select {
case <-ctx.Done():
return
case <-h.requestQueueWakeup:
}
}
}
func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) {
defer doneFunc()
log := zerolog.Ctx(ctx)
if !req.BackupChecked {
sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID)
if err != nil {
log.Err(err).
Stringer("session_id", req.SessionID).
Msg("Failed to fetch session from key backup")
// TODO should this have retries instead of just storing it's checked?
req.BackupChecked = true
err = h.DB.SessionRequest.Put(ctx, req)
if err != nil {
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup")
}
} else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex {
req.BackupChecked = true
err = h.DB.SessionRequest.Put(ctx, req)
if err != nil {
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup")
}
} else {
log.Debug().Stringer("session_id", req.SessionID).
Msg("Found session with sufficiently low first known index, removing from queue")
err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex())
if err != nil {
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue")
}
}
} else {
err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{
h.Account.UserID: {"*"},
req.Sender: {"*"},
})
//var err error
if err != nil {
log.Err(err).
Stringer("session_id", req.SessionID).
Msg("Failed to send key request")
} else {
log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request")
req.RequestSent = true
err = h.DB.SessionRequest.Put(ctx, req)
if err != nil {
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request")
}
}
}
}
const MaxParallelRequests = 5
func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) {
sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests)
if err != nil {
return false, fmt.Errorf("failed to get next events to decrypt: %w", err)
} else if len(sessions) == 0 {
return false, nil
}
var wg sync.WaitGroup
wg.Add(len(sessions))
for _, req := range sessions {
go h.requestQueuedSession(ctx, req, wg.Done)
}
wg.Wait()
return true, err
}
func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error {
outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx)
if err != nil {
return err
} else if len(outdatedUsers) == 0 {
return nil
}
_, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false)
if err != nil {
return err
}
// TODO backoff for users that fail to be fetched?
return nil
}

60
pkg/hicli/events.go Normal file
View file

@ -0,0 +1,60 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
)
type SyncRoom struct {
Meta *database.Room `json:"meta"`
Timeline []database.TimelineRowTuple `json:"timeline"`
State map[event.Type]map[string]database.EventRowID `json:"state"`
Events []*database.Event `json:"events"`
Reset bool `json:"reset"`
Notifications []SyncNotification `json:"notifications"`
}
type SyncNotification struct {
RowID database.EventRowID `json:"event_rowid"`
Sound bool `json:"sound"`
}
type SyncComplete struct {
Rooms map[id.RoomID]*SyncRoom `json:"rooms"`
}
func (c *SyncComplete) IsEmpty() bool {
return len(c.Rooms) == 0
}
type EventsDecrypted struct {
RoomID id.RoomID `json:"room_id"`
PreviewEventRowID database.EventRowID `json:"preview_event_rowid,omitempty"`
Events []*database.Event `json:"events"`
}
type Typing struct {
RoomID id.RoomID `json:"room_id"`
event.TypingEventContent
}
type SendComplete struct {
Event *database.Event `json:"event"`
Error error `json:"error"`
}
type ClientState struct {
IsLoggedIn bool `json:"is_logged_in"`
IsVerified bool `json:"is_verified"`
UserID id.UserID `json:"user_id,omitempty"`
DeviceID id.DeviceID `json:"device_id,omitempty"`
HomeserverURL string `json:"homeserver_url,omitempty"`
}

251
pkg/hicli/hicli.go Normal file
View file

@ -0,0 +1,251 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix.
package hicli
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exerrors"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
"go.mau.fi/gomuks/pkg/hicli/database"
)
type HiClient struct {
DB *database.Database
Account *database.Account
Client *mautrix.Client
Crypto *crypto.OlmMachine
CryptoStore *crypto.SQLCryptoStore
ClientStore *database.ClientStateStore
Log zerolog.Logger
Verified bool
KeyBackupVersion id.KeyBackupVersion
KeyBackupKey *backup.MegolmBackupKey
PushRules atomic.Pointer[pushrules.PushRuleset]
EventHandler func(evt any)
firstSyncReceived bool
syncingID int
syncLock sync.Mutex
stopSync atomic.Pointer[context.CancelFunc]
encryptLock sync.Mutex
requestQueueWakeup chan struct{}
jsonRequestsLock sync.Mutex
jsonRequests map[int64]context.CancelCauseFunc
paginationInterrupterLock sync.Mutex
paginationInterrupter map[id.RoomID]context.CancelCauseFunc
}
var ErrTimelineReset = errors.New("got limited timeline sync response")
func New(rawDB, cryptoDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandler func(any)) *HiClient {
if cryptoDB == nil {
cryptoDB = rawDB
}
if rawDB.Owner == "" {
rawDB.Owner = "hicli"
rawDB.IgnoreForeignTables = true
}
if rawDB.Log == nil {
rawDB.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger())
}
db := database.New(rawDB)
c := &HiClient{
DB: db,
Log: log,
requestQueueWakeup: make(chan struct{}, 1),
jsonRequests: make(map[int64]context.CancelCauseFunc),
paginationInterrupter: make(map[id.RoomID]context.CancelCauseFunc),
EventHandler: evtHandler,
}
c.ClientStore = &database.ClientStateStore{Database: db}
c.Client = &mautrix.Client{
UserAgent: mautrix.DefaultUserAgent,
Client: &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
// This needs to be relatively high to allow initial syncs
ResponseHeaderTimeout: 180 * time.Second,
ForceAttemptHTTP2: true,
},
Timeout: 180 * time.Second,
},
Syncer: (*hiSyncer)(c),
Store: (*hiStore)(c),
StateStore: c.ClientStore,
Log: log.With().Str("component", "mautrix client").Logger(),
}
c.CryptoStore = crypto.NewSQLCryptoStore(cryptoDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey)
cryptoLog := log.With().Str("component", "crypto").Logger()
c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore)
c.Crypto.SessionReceived = c.handleReceivedMegolmSession
c.Crypto.DisableRatchetTracking = true
c.Crypto.DisableDecryptKeyFetching = true
c.Client.Crypto = (*hiCryptoHelper)(c)
return c
}
func (h *HiClient) IsLoggedIn() bool {
return h.Account != nil
}
func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount *database.Account) error {
if expectedAccount != nil && userID != expectedAccount.UserID {
panic(fmt.Errorf("invalid parameters: different user ID in expected account and user ID"))
}
err := h.DB.Upgrade(ctx)
if err != nil {
return fmt.Errorf("failed to upgrade hicli db: %w", err)
}
err = h.CryptoStore.DB.Upgrade(ctx)
if err != nil {
return fmt.Errorf("failed to upgrade crypto db: %w", err)
}
account, err := h.DB.Account.Get(ctx, userID)
if err != nil {
return err
} else if account == nil && expectedAccount != nil {
err = h.DB.Account.Put(ctx, expectedAccount)
if err != nil {
return err
}
account = expectedAccount
} else if expectedAccount != nil && expectedAccount.DeviceID != account.DeviceID {
return fmt.Errorf("device ID mismatch: expected %s, got %s", expectedAccount.DeviceID, account.DeviceID)
}
if account != nil {
zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials")
h.Account = account
h.CryptoStore.AccountID = account.UserID.String()
h.CryptoStore.DeviceID = account.DeviceID
h.Client.UserID = account.UserID
h.Client.DeviceID = account.DeviceID
h.Client.AccessToken = account.AccessToken
h.Client.HomeserverURL, err = url.Parse(account.HomeserverURL)
if err != nil {
return err
}
err = h.CheckServerVersions(ctx)
if err != nil {
return err
}
err = h.Crypto.Load(ctx)
if err != nil {
return fmt.Errorf("failed to load olm machine: %w", err)
}
h.Verified, err = h.checkIsCurrentDeviceVerified(ctx)
if err != nil {
return err
}
zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status")
if h.Verified {
err = h.loadPrivateKeys(ctx)
if err != nil {
return err
}
go h.Sync()
}
}
return nil
}
var ErrFailedToCheckServerVersions = errors.New("failed to check server versions")
var ErrOutdatedServer = errors.New("homeserver is outdated")
var MinimumSpecVersion = mautrix.SpecV11
func (h *HiClient) CheckServerVersions(ctx context.Context) error {
versions, err := h.Client.Versions(ctx)
if err != nil {
return exerrors.NewDualError(ErrFailedToCheckServerVersions, err)
} else if !versions.Contains(MinimumSpecVersion) {
return fmt.Errorf("%w (minimum: %s, highest supported: %s)", ErrOutdatedServer, MinimumSpecVersion, versions.GetLatest())
}
return nil
}
func (h *HiClient) IsSyncing() bool {
return h.stopSync.Load() != nil
}
func (h *HiClient) Sync() {
h.Client.StopSync()
if fn := h.stopSync.Load(); fn != nil {
(*fn)()
}
h.syncLock.Lock()
defer h.syncLock.Unlock()
h.syncingID++
syncingID := h.syncingID
log := h.Log.With().
Str("action", "sync").
Int("sync_id", syncingID).
Logger()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
h.stopSync.Store(&cancel)
go h.RunRequestQueue(h.Log.WithContext(ctx))
go h.LoadPushRules(h.Log.WithContext(ctx))
ctx = log.WithContext(ctx)
log.Info().Msg("Starting syncing")
err := h.Client.SyncWithContext(ctx)
if err != nil && ctx.Err() == nil {
log.Err(err).Msg("Fatal error in syncer")
} else {
log.Info().Msg("Syncing stopped")
}
}
func (h *HiClient) LoadPushRules(ctx context.Context) {
rules, err := h.Client.GetPushRules(ctx)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to load push rules")
return
}
h.PushRules.Store(rules)
zerolog.Ctx(ctx).Debug().Msg("Updated push rules from fetch")
}
func (h *HiClient) Stop() {
h.Client.StopSync()
if fn := h.stopSync.Swap(nil); fn != nil {
(*fn)()
}
h.syncLock.Lock()
//lint:ignore SA2001 just acquire the lock to make sure Sync is done
h.syncLock.Unlock()
err := h.DB.Close()
if err != nil {
h.Log.Err(err).Msg("Failed to close database cleanly")
}
}

110
pkg/hicli/hitest/hitest.go Normal file
View file

@ -0,0 +1,110 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package main
import (
"context"
"fmt"
"io"
"strings"
"github.com/chzyer/readline"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
_ "go.mau.fi/util/dbutil/litestream"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exzerolog"
"go.mau.fi/zeroconfig"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli"
)
var writerTypeReadline zeroconfig.WriterType = "hitest_readline"
func main() {
hicli.InitialDeviceDisplayName = "mautrix hitest"
rl := exerrors.Must(readline.New("> "))
defer func() {
_ = rl.Close()
}()
zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) {
return rl.Stdout(), nil
})
debug := zerolog.DebugLevel
log := exerrors.Must((&zeroconfig.Config{
MinLevel: &debug,
Writers: []zeroconfig.WriterConfig{{
Type: writerTypeReadline,
Format: zeroconfig.LogFormatPrettyColored,
}},
}).Compile())
exzerolog.SetupDefaults(log)
rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal"))
ctx := log.WithContext(context.Background())
cli := hicli.New(rawDB, nil, *log, []byte("meow"), func(a any) {
_, _ = fmt.Fprintf(rl, "Received event of type %T\n", a)
switch evt := a.(type) {
case *hicli.SyncComplete:
for _, room := range evt.Rooms {
name := "name unset"
if room.Meta.Name != nil {
name = *room.Meta.Name
}
_, _ = fmt.Fprintf(rl, "Room %s (%s) in sync:\n", name, room.Meta.ID)
_, _ = fmt.Fprintf(rl, " Preview: %d, sort: %v\n", room.Meta.PreviewEventRowID, room.Meta.SortingTimestamp)
_, _ = fmt.Fprintf(rl, " Timeline: +%d %v, reset: %t\n", len(room.Timeline), room.Timeline, room.Reset)
}
case *hicli.EventsDecrypted:
for _, decrypted := range evt.Events {
_, _ = fmt.Fprintf(rl, "Delayed decryption of %s completed: %s / %s\n", decrypted.ID, decrypted.DecryptedType, decrypted.Decrypted)
}
if evt.PreviewEventRowID != 0 {
_, _ = fmt.Fprintf(rl, "Room preview updated: %+v\n", evt.PreviewEventRowID)
}
case *hicli.Typing:
_, _ = fmt.Fprintf(rl, "Typing list in %s: %+v\n", evt.RoomID, evt.UserIDs)
}
})
userID, _ := cli.DB.Account.GetFirstUserID(ctx)
exerrors.PanicIfNotNil(cli.Start(ctx, userID, nil))
if !cli.IsLoggedIn() {
rl.SetPrompt("User ID: ")
userID := id.UserID(exerrors.Must(rl.Readline()))
_, serverName := exerrors.Must2(userID.Parse())
discovery := exerrors.Must(mautrix.DiscoverClientAPI(ctx, serverName))
password := exerrors.Must(rl.ReadPassword("Password: "))
recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: "))
exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode)))
}
rl.SetPrompt("> ")
for {
line, err := rl.Readline()
if err != nil {
break
}
fields := strings.Fields(line)
if len(fields) == 0 {
continue
}
switch strings.ToLower(fields[0]) {
case "send":
resp, err := cli.Send(ctx, id.RoomID(fields[1]), event.EventMessage, &event.MessageEventContent{
Body: strings.Join(fields[2:], " "),
MsgType: event.MsgText,
})
_, _ = fmt.Fprintln(rl, err)
_, _ = fmt.Fprintf(rl, "%+v\n", resp)
}
}
cli.Stop()
}

493
pkg/hicli/html.go Normal file
View file

@ -0,0 +1,493 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"bytes"
"errors"
"fmt"
"io"
"net/url"
"regexp"
"slices"
"strconv"
"strings"
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
"maunium.net/go/mautrix/id"
"mvdan.cc/xurls/v2"
)
func tagIsAllowed(tag atom.Atom) bool {
switch tag {
case atom.Del, atom.H1, atom.H2, atom.H3, atom.H4, atom.H5, atom.H6, atom.Blockquote, atom.P,
atom.A, atom.Ul, atom.Ol, atom.Sup, atom.Sub, atom.Li, atom.B, atom.I, atom.U, atom.Strong,
atom.Em, atom.S, atom.Code, atom.Hr, atom.Br, atom.Div, atom.Table, atom.Thead, atom.Tbody,
atom.Tr, atom.Th, atom.Td, atom.Caption, atom.Pre, atom.Span, atom.Font, atom.Img,
atom.Details, atom.Summary:
return true
default:
return false
}
}
func isSelfClosing(tag atom.Atom) bool {
switch tag {
case atom.Img, atom.Br, atom.Hr:
return true
default:
return false
}
}
var languageRegex = regexp.MustCompile(`^language-[a-zA-Z0-9-]+$`)
var allowedColorRegex = regexp.MustCompile(`^#[0-9a-fA-F]{6}$`)
// This is approximately a mirror of web/src/util/mediasize.ts in gomuks
func calculateMediaSize(widthInt, heightInt int) (width, height float64, ok bool) {
if widthInt <= 0 || heightInt <= 0 {
return
}
width = float64(widthInt)
height = float64(heightInt)
const imageContainerWidth float64 = 320
const imageContainerHeight float64 = 240
const imageContainerAspectRatio = imageContainerWidth / imageContainerHeight
if width > imageContainerWidth || height > imageContainerHeight {
aspectRatio := width / height
if aspectRatio > imageContainerAspectRatio {
width = imageContainerWidth
height = imageContainerWidth / aspectRatio
} else if aspectRatio < imageContainerAspectRatio {
width = imageContainerHeight * aspectRatio
height = imageContainerHeight
} else {
width = imageContainerWidth
height = imageContainerHeight
}
}
ok = true
return
}
func parseImgAttributes(attrs []html.Attribute) (src, alt, title string, isCustomEmoji bool, width, height int) {
for _, attr := range attrs {
switch attr.Key {
case "src":
src = attr.Val
case "alt":
alt = attr.Val
case "title":
title = attr.Val
case "data-mx-emoticon":
isCustomEmoji = true
case "width":
width, _ = strconv.Atoi(attr.Val)
case "height":
height, _ = strconv.Atoi(attr.Val)
}
}
return
}
func parseSpanAttributes(attrs []html.Attribute) (bgColor, textColor, spoiler, maths string, isSpoiler bool) {
for _, attr := range attrs {
switch attr.Key {
case "data-mx-bg-color":
if allowedColorRegex.MatchString(attr.Val) {
bgColor = attr.Val
}
case "data-mx-color", "color":
if allowedColorRegex.MatchString(attr.Val) {
textColor = attr.Val
}
case "data-mx-spoiler":
spoiler = attr.Val
isSpoiler = true
case "data-mx-maths":
maths = attr.Val
}
}
return
}
func parseAAttributes(attrs []html.Attribute) (href string) {
for _, attr := range attrs {
switch attr.Key {
case "href":
href = strings.TrimSpace(attr.Val)
}
}
return
}
func attributeIsAllowed(tag atom.Atom, attr html.Attribute) bool {
switch tag {
case atom.Ol:
switch attr.Key {
case "start":
_, err := strconv.Atoi(attr.Val)
return err == nil
}
case atom.Code:
switch attr.Key {
case "class":
return languageRegex.MatchString(attr.Val)
}
case atom.Div:
switch attr.Key {
case "data-mx-maths":
return true
}
}
return false
}
// Funny user IDs will just need to be linkified by the sender, no auto-linkification for them.
var plainUserOrAliasMentionRegex = regexp.MustCompile(`[@#][a-zA-Z0-9._=/+-]{0,254}:[a-zA-Z0-9.-]+(?:\d{1,5})?`)
func getNextItem(items [][]int, minIndex int) (index, start, end int, ok bool) {
for i, item := range items {
if item[0] >= minIndex {
return i, item[0], item[1], true
}
}
return -1, -1, -1, false
}
func writeMention(w *strings.Builder, mention []byte) {
uri := &id.MatrixURI{
Sigil1: rune(mention[0]),
MXID1: string(mention[1:]),
}
w.WriteString(`<a`)
writeAttribute(w, "href", uri.String())
writeAttribute(w, "class", matrixURIClassName(uri)+" hicli-matrix-uri-plaintext")
w.WriteByte('>')
writeEscapedBytes(w, mention)
w.WriteString("</a>")
}
func writeURL(w *strings.Builder, addr []byte) {
parsedURL, err := url.Parse(string(addr))
if err != nil {
writeEscapedBytes(w, addr)
return
}
if parsedURL.Scheme == "" {
parsedURL.Scheme = "https"
}
w.WriteString(`<a target="_blank" rel="noreferrer noopener"`)
writeAttribute(w, "href", parsedURL.String())
w.WriteByte('>')
writeEscapedBytes(w, addr)
w.WriteString("</a>")
}
func linkifyAndWriteBytes(w *strings.Builder, s []byte) {
mentions := plainUserOrAliasMentionRegex.FindAllIndex(s, -1)
urls := xurls.Relaxed().FindAllIndex(s, -1)
minIndex := 0
for {
mentionIdx, nextMentionStart, nextMentionEnd, hasMention := getNextItem(mentions, minIndex)
urlIdx, nextURLStart, nextURLEnd, hasURL := getNextItem(urls, minIndex)
if hasMention && (!hasURL || nextMentionStart <= nextURLStart) {
writeEscapedBytes(w, s[minIndex:nextMentionStart])
writeMention(w, s[nextMentionStart:nextMentionEnd])
minIndex = nextMentionEnd
mentions = mentions[mentionIdx:]
} else if hasURL && (!hasMention || nextURLStart < nextMentionStart) {
writeEscapedBytes(w, s[minIndex:nextURLStart])
writeURL(w, s[nextURLStart:nextURLEnd])
minIndex = nextURLEnd
urls = urls[urlIdx:]
} else {
break
}
}
writeEscapedBytes(w, s[minIndex:])
}
const escapedChars = "&'<>\"\r"
func writeEscapedBytes(w *strings.Builder, s []byte) {
i := bytes.IndexAny(s, escapedChars)
for i != -1 {
w.Write(s[:i])
var esc string
switch s[i] {
case '&':
esc = "&amp;"
case '\'':
// "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
esc = "&#39;"
case '<':
esc = "&lt;"
case '>':
esc = "&gt;"
case '"':
// "&#34;" is shorter than "&quot;".
esc = "&#34;"
case '\r':
esc = "&#13;"
default:
panic("unrecognized escape character")
}
s = s[i+1:]
w.WriteString(esc)
i = bytes.IndexAny(s, escapedChars)
}
w.Write(s)
}
func writeEscapedString(w *strings.Builder, s string) {
i := strings.IndexAny(s, escapedChars)
for i != -1 {
w.WriteString(s[:i])
var esc string
switch s[i] {
case '&':
esc = "&amp;"
case '\'':
// "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
esc = "&#39;"
case '<':
esc = "&lt;"
case '>':
esc = "&gt;"
case '"':
// "&#34;" is shorter than "&quot;".
esc = "&#34;"
case '\r':
esc = "&#13;"
default:
panic("unrecognized escape character")
}
s = s[i+1:]
w.WriteString(esc)
i = strings.IndexAny(s, escapedChars)
}
w.WriteString(s)
}
func writeAttribute(w *strings.Builder, key, value string) {
w.WriteByte(' ')
w.WriteString(key)
w.WriteString(`="`)
writeEscapedString(w, value)
w.WriteByte('"')
}
func matrixURIClassName(uri *id.MatrixURI) string {
switch uri.Sigil1 {
case '@':
return "hicli-matrix-uri hicli-matrix-uri-user"
case '#':
return "hicli-matrix-uri hicli-matrix-uri-room-alias"
case '!':
if uri.Sigil2 == '$' {
return "hicli-matrix-uri hicli-matrix-uri-event-id"
}
return "hicli-matrix-uri hicli-matrix-uri-room-id"
default:
return "hicli-matrix-uri hicli-matrix-uri-unknown"
}
}
func writeA(w *strings.Builder, attr []html.Attribute) {
w.WriteString("<a")
href := parseAAttributes(attr)
if href == "" {
return
}
parsedURL, err := url.Parse(href)
if err != nil {
return
}
newTab := true
switch parsedURL.Scheme {
case "bitcoin", "ftp", "geo", "http", "im", "irc", "ircs", "magnet", "mailto",
"mms", "news", "nntp", "openpgp4fpr", "sip", "sftp", "sms", "smsto", "ssh",
"tel", "urn", "webcal", "wtai", "xmpp":
// allowed
case "https":
if parsedURL.Host == "matrix.to" {
uri, err := id.ProcessMatrixToURL(parsedURL)
if err != nil {
return
}
href = uri.String()
newTab = false
writeAttribute(w, "class", matrixURIClassName(uri))
}
case "matrix":
uri, err := id.ProcessMatrixURI(parsedURL)
if err != nil {
return
}
href = uri.String()
newTab = false
writeAttribute(w, "class", matrixURIClassName(uri))
case "mxc":
mxc := id.ContentURIString(href).ParseOrIgnore()
if !mxc.IsValid() {
return
}
href = fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID)
default:
return
}
writeAttribute(w, "href", href)
if newTab {
writeAttribute(w, "target", "_blank")
writeAttribute(w, "rel", "noreferrer noopener")
}
}
var HTMLSanitizerImgSrcTemplate = "mxc://%s/%s"
func writeImg(w *strings.Builder, attr []html.Attribute) {
src, alt, title, isCustomEmoji, width, height := parseImgAttributes(attr)
w.WriteString("<img")
writeAttribute(w, "alt", alt)
if title != "" {
writeAttribute(w, "title", title)
}
mxc := id.ContentURIString(src).ParseOrIgnore()
if !mxc.IsValid() {
return
}
writeAttribute(w, "src", fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID))
writeAttribute(w, "loading", "lazy")
if isCustomEmoji {
writeAttribute(w, "class", "hicli-custom-emoji")
} else if cWidth, cHeight, sizeOK := calculateMediaSize(width, height); sizeOK {
writeAttribute(w, "class", "hicli-sized-inline-img")
writeAttribute(w, "style", fmt.Sprintf("width: %.2fpx; height: %.2fpx;", cWidth, cHeight))
} else {
writeAttribute(w, "class", "hicli-sizeless-inline-img")
}
}
func writeSpan(w *strings.Builder, attr []html.Attribute) {
bgColor, textColor, spoiler, _, isSpoiler := parseSpanAttributes(attr)
if isSpoiler && spoiler != "" {
w.WriteString(`<span class="spoiler-reason">`)
w.WriteString(spoiler)
w.WriteString(" </span>")
}
w.WriteByte('<')
w.WriteString("span")
if isSpoiler {
writeAttribute(w, "class", "hicli-spoiler")
}
var style string
if bgColor != "" {
style += fmt.Sprintf("background-color: %s;", bgColor)
}
if textColor != "" {
style += fmt.Sprintf("color: %s;", textColor)
}
if style != "" {
writeAttribute(w, "style", style)
}
}
type tagStack []atom.Atom
func (ts *tagStack) contains(tags ...atom.Atom) bool {
for i := len(*ts) - 1; i >= 0; i-- {
for _, tag := range tags {
if (*ts)[i] == tag {
return true
}
}
}
return false
}
func (ts *tagStack) push(tag atom.Atom) {
*ts = append(*ts, tag)
}
func (ts *tagStack) pop(tag atom.Atom) bool {
if len(*ts) > 0 && (*ts)[len(*ts)-1] == tag {
*ts = (*ts)[:len(*ts)-1]
return true
}
return false
}
func sanitizeAndLinkifyHTML(body string) (string, error) {
tz := html.NewTokenizer(strings.NewReader(body))
var built strings.Builder
ts := make(tagStack, 2)
Loop:
for {
switch tz.Next() {
case html.ErrorToken:
err := tz.Err()
if errors.Is(err, io.EOF) {
break Loop
}
return "", err
case html.StartTagToken, html.SelfClosingTagToken:
token := tz.Token()
if !tagIsAllowed(token.DataAtom) {
continue
}
tagIsSelfClosing := isSelfClosing(token.DataAtom)
if token.Type == html.SelfClosingTagToken && !tagIsSelfClosing {
continue
}
switch token.DataAtom {
case atom.A:
writeA(&built, token.Attr)
case atom.Img:
writeImg(&built, token.Attr)
case atom.Span, atom.Font:
writeSpan(&built, token.Attr)
default:
built.WriteByte('<')
built.WriteString(token.Data)
for _, attr := range token.Attr {
if attributeIsAllowed(token.DataAtom, attr) {
writeAttribute(&built, attr.Key, attr.Val)
}
}
}
built.WriteByte('>')
if !tagIsSelfClosing {
ts.push(token.DataAtom)
}
case html.EndTagToken:
tagName, _ := tz.TagName()
tag := atom.Lookup(tagName)
if tagIsAllowed(tag) && ts.pop(tag) {
built.WriteString("</")
built.Write(tagName)
built.WriteByte('>')
}
case html.TextToken:
if ts.contains(atom.Pre, atom.Code, atom.A) {
writeEscapedBytes(&built, tz.Text())
} else {
linkifyAndWriteBytes(&built, tz.Text())
}
case html.DoctypeToken, html.CommentToken:
// ignore
}
}
slices.Reverse(ts)
for _, t := range ts {
built.WriteString("</")
built.WriteString(t.String())
built.WriteByte('>')
}
return built.String(), nil
}

179
pkg/hicli/json-commands.go Normal file
View file

@ -0,0 +1,179 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
)
func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any, error) {
switch req.Command {
case "get_state":
return h.State(), nil
case "cancel":
return unmarshalAndCall(req.Data, func(params *cancelRequestParams) (bool, error) {
h.jsonRequestsLock.Lock()
cancelTarget, ok := h.jsonRequests[params.RequestID]
h.jsonRequestsLock.Unlock()
if ok {
return false, nil
}
if params.Reason == "" {
cancelTarget(nil)
} else {
cancelTarget(errors.New(params.Reason))
}
return true, nil
})
case "send_message":
return unmarshalAndCall(req.Data, func(params *sendMessageParams) (*database.Event, error) {
return h.SendMessage(ctx, params.RoomID, params.Text, params.MediaPath, params.ReplyTo, params.Mentions)
})
case "send_event":
return unmarshalAndCall(req.Data, func(params *sendEventParams) (*database.Event, error) {
return h.Send(ctx, params.RoomID, params.EventType, params.Content)
})
case "mark_read":
return unmarshalAndCall(req.Data, func(params *markReadParams) (bool, error) {
return true, h.MarkRead(ctx, params.RoomID, params.EventID, params.ReceiptType)
})
case "set_typing":
return unmarshalAndCall(req.Data, func(params *setTypingParams) (bool, error) {
return true, h.SetTyping(ctx, params.RoomID, time.Duration(params.Timeout)*time.Millisecond)
})
case "get_event":
return unmarshalAndCall(req.Data, func(params *getEventParams) (*database.Event, error) {
return h.GetEvent(ctx, params.RoomID, params.EventID)
})
case "get_events_by_rowids":
return unmarshalAndCall(req.Data, func(params *getEventsByRowIDsParams) ([]*database.Event, error) {
return h.GetEventsByRowIDs(ctx, params.RowIDs)
})
case "get_room_state":
return unmarshalAndCall(req.Data, func(params *getRoomStateParams) ([]*database.Event, error) {
return h.GetRoomState(ctx, params.RoomID, params.FetchMembers, params.Refetch)
})
case "paginate":
return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) {
return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit)
})
case "paginate_server":
return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) {
return h.PaginateServer(ctx, params.RoomID, params.Limit)
})
case "ensure_group_session_shared":
return unmarshalAndCall(req.Data, func(params *ensureGroupSessionSharedParams) (bool, error) {
return true, h.EnsureGroupSessionShared(ctx, params.RoomID)
})
case "login":
return unmarshalAndCall(req.Data, func(params *loginParams) (bool, error) {
return true, h.LoginPassword(ctx, params.HomeserverURL, params.Username, params.Password)
})
case "verify":
return unmarshalAndCall(req.Data, func(params *verifyParams) (bool, error) {
return true, h.VerifyWithRecoveryKey(ctx, params.RecoveryKey)
})
case "discover_homeserver":
return unmarshalAndCall(req.Data, func(params *discoverHomeserverParams) (*mautrix.ClientWellKnown, error) {
_, homeserver, err := params.UserID.Parse()
if err != nil {
return nil, err
}
return mautrix.DiscoverClientAPI(ctx, homeserver)
})
default:
return nil, fmt.Errorf("unknown command %q", req.Command)
}
}
func unmarshalAndCall[T, O any](data json.RawMessage, fn func(*T) (O, error)) (output O, err error) {
var input T
err = json.Unmarshal(data, &input)
if err != nil {
return
}
return fn(&input)
}
type cancelRequestParams struct {
RequestID int64 `json:"request_id"`
Reason string `json:"reason"`
}
type sendMessageParams struct {
RoomID id.RoomID `json:"room_id"`
Text string `json:"text"`
MediaPath string `json:"media_path"`
ReplyTo id.EventID `json:"reply_to"`
Mentions *event.Mentions `json:"mentions"`
}
type sendEventParams struct {
RoomID id.RoomID `json:"room_id"`
EventType event.Type `json:"type"`
Content json.RawMessage `json:"content"`
}
type markReadParams struct {
RoomID id.RoomID `json:"room_id"`
EventID id.EventID `json:"event_id"`
ReceiptType event.ReceiptType `json:"receipt_type"`
}
type setTypingParams struct {
RoomID id.RoomID `json:"room_id"`
Timeout int `json:"timeout"`
}
type getEventParams struct {
RoomID id.RoomID `json:"room_id"`
EventID id.EventID `json:"event_id"`
}
type getEventsByRowIDsParams struct {
RowIDs []database.EventRowID `json:"row_ids"`
}
type getRoomStateParams struct {
RoomID id.RoomID `json:"room_id"`
Refetch bool `json:"refetch"`
FetchMembers bool `json:"fetch_members"`
}
type ensureGroupSessionSharedParams struct {
RoomID id.RoomID `json:"room_id"`
}
type loginParams struct {
HomeserverURL string `json:"homeserver_url"`
Username string `json:"username"`
Password string `json:"password"`
}
type verifyParams struct {
RecoveryKey string `json:"recovery_key"`
}
type discoverHomeserverParams struct {
UserID id.UserID `json:"user_id"`
}
type paginateParams struct {
RoomID id.RoomID `json:"room_id"`
MaxTimelineID database.TimelineRowID `json:"max_timeline_id"`
Limit int `json:"limit"`
}

119
pkg/hicli/json.go Normal file
View file

@ -0,0 +1,119 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync/atomic"
"go.mau.fi/util/exerrors"
)
type JSONCommand struct {
Command string `json:"command"`
RequestID int64 `json:"request_id"`
Data json.RawMessage `json:"data"`
}
type JSONEventHandler func(*JSONCommand)
var outgoingEventCounter atomic.Int64
func (jeh JSONEventHandler) HandleEvent(evt any) {
var command string
switch evt.(type) {
case *SyncComplete:
command = "sync_complete"
case *EventsDecrypted:
command = "events_decrypted"
case *Typing:
command = "typing"
case *SendComplete:
command = "send_complete"
case *ClientState:
command = "client_state"
default:
panic(fmt.Errorf("unknown event type %T", evt))
}
data, err := json.Marshal(evt)
if err != nil {
panic(fmt.Errorf("failed to marshal event %T: %w", evt, err))
}
jeh(&JSONCommand{
Command: command,
RequestID: -outgoingEventCounter.Add(1),
Data: data,
})
}
func (h *HiClient) State() *ClientState {
state := &ClientState{}
if acc := h.Account; acc != nil {
state.IsLoggedIn = true
state.UserID = acc.UserID
state.DeviceID = acc.DeviceID
state.HomeserverURL = acc.HomeserverURL
state.IsVerified = h.Verified
}
return state
}
func (h *HiClient) dispatchCurrentState() {
h.EventHandler(h.State())
}
func (h *HiClient) SubmitJSONCommand(ctx context.Context, req *JSONCommand) *JSONCommand {
if req.Command == "ping" {
return &JSONCommand{
Command: "pong",
RequestID: req.RequestID,
}
}
log := h.Log.With().Int64("request_id", req.RequestID).Str("command", req.Command).Logger()
ctx, cancel := context.WithCancelCause(ctx)
defer func() {
cancel(nil)
h.jsonRequestsLock.Lock()
delete(h.jsonRequests, req.RequestID)
h.jsonRequestsLock.Unlock()
}()
ctx = log.WithContext(ctx)
h.jsonRequestsLock.Lock()
h.jsonRequests[req.RequestID] = cancel
h.jsonRequestsLock.Unlock()
resp, err := h.handleJSONCommand(ctx, req)
if err != nil {
if errors.Is(err, context.Canceled) {
causeErr := context.Cause(ctx)
if causeErr != ctx.Err() {
err = fmt.Errorf("%w: %w", err, causeErr)
}
}
return &JSONCommand{
Command: "error",
RequestID: req.RequestID,
Data: exerrors.Must(json.Marshal(err.Error())),
}
}
var respData json.RawMessage
respData, err = json.Marshal(resp)
if err != nil {
return &JSONCommand{
Command: "error",
RequestID: req.RequestID,
Data: exerrors.Must(json.Marshal(fmt.Sprintf("failed to marshal response json: %v", err))),
}
}
return &JSONCommand{
Command: "response",
RequestID: req.RequestID,
Data: respData,
}
}

88
pkg/hicli/login.go Normal file
View file

@ -0,0 +1,88 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"fmt"
"net/url"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
)
var InitialDeviceDisplayName = "mautrix hiclient"
func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error {
var err error
h.Client.HomeserverURL, err = url.Parse(homeserverURL)
if err != nil {
return err
}
return h.Login(ctx, &mautrix.ReqLogin{
Type: mautrix.AuthTypePassword,
Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser,
User: username,
},
Password: password,
InitialDeviceDisplayName: InitialDeviceDisplayName,
})
}
func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error {
err := h.CheckServerVersions(ctx)
if err != nil {
return err
}
req.StoreCredentials = true
req.StoreHomeserverURL = true
resp, err := h.Client.Login(ctx, req)
if err != nil {
return err
}
defer h.dispatchCurrentState()
h.Account = &database.Account{
UserID: resp.UserID,
DeviceID: resp.DeviceID,
AccessToken: resp.AccessToken,
HomeserverURL: h.Client.HomeserverURL.String(),
}
h.CryptoStore.AccountID = resp.UserID.String()
h.CryptoStore.DeviceID = resp.DeviceID
err = h.DB.Account.Put(ctx, h.Account)
if err != nil {
return err
}
err = h.Crypto.Load(ctx)
if err != nil {
return fmt.Errorf("failed to load olm machine: %w", err)
}
err = h.Crypto.ShareKeys(ctx, 0)
if err != nil {
return err
}
_, err = h.Crypto.FetchKeys(ctx, []id.UserID{h.Account.UserID}, true)
if err != nil {
return fmt.Errorf("failed to fetch own devices: %w", err)
}
return nil
}
func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryKey string) error {
err := h.LoginPassword(ctx, homeserverURL, username, password)
if err != nil {
return err
}
err = h.VerifyWithRecoveryKey(ctx, recoveryKey)
if err != nil {
return err
}
return nil
}

245
pkg/hicli/paginate.go Normal file
View file

@ -0,0 +1,245 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"errors"
"fmt"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
)
var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress")
func (h *HiClient) GetEventsByRowIDs(ctx context.Context, rowIDs []database.EventRowID) ([]*database.Event, error) {
events, err := h.DB.Event.GetByRowIDs(ctx, rowIDs...)
if err != nil {
return nil, err
} else if len(events) == 0 {
return events, nil
}
firstRoomID := events[0].RoomID
allInSameRoom := true
for _, evt := range events {
h.ReprocessExistingEvent(ctx, evt)
if evt.RoomID != firstRoomID {
allInSameRoom = false
break
}
}
if allInSameRoom {
err = h.DB.Event.FillLastEditRowIDs(ctx, firstRoomID, events)
if err != nil {
return events, fmt.Errorf("failed to fill last edit row IDs: %w", err)
}
err = h.DB.Event.FillReactionCounts(ctx, firstRoomID, events)
if err != nil {
return events, fmt.Errorf("failed to fill reaction counts: %w", err)
}
} else {
// TODO slow path where events are collected and filling is done one room at a time?
}
return events, nil
}
func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) {
if evt, err := h.DB.Event.GetByID(ctx, eventID); err != nil {
return nil, fmt.Errorf("failed to get event from database: %w", err)
} else if evt != nil {
h.ReprocessExistingEvent(ctx, evt)
return evt, nil
} else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil {
return nil, fmt.Errorf("failed to get event from server: %w", err)
} else {
return h.processEvent(ctx, serverEvt, nil, nil, false)
}
}
func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) {
var evts []*event.Event
if refetch {
resp, err := h.Client.StateAsArray(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to refetch state: %w", err)
}
evts = resp
} else if fetchMembers {
resp, err := h.Client.Members(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to fetch members: %w", err)
}
evts = resp.Chunk
}
if evts != nil {
err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
room, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to get room from database: %w", err)
}
updatedRoom := &database.Room{
ID: room.ID,
HasMemberList: true,
}
entries := make([]*database.CurrentStateEntry, len(evts))
for i, evt := range evts {
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false)
if err != nil {
return fmt.Errorf("failed to process event %s: %w", evt.ID, err)
}
entries[i] = &database.CurrentStateEntry{
EventType: evt.Type,
StateKey: *evt.StateKey,
EventRowID: dbEvt.RowID,
}
if evt.Type == event.StateMember {
entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string))
} else {
processImportantEvent(ctx, evt, room, updatedRoom)
}
}
err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries)
if err != nil {
return err
}
roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
if roomChanged {
err = h.DB.Room.Upsert(ctx, updatedRoom)
if err != nil {
return fmt.Errorf("failed to save room data: %w", err)
}
}
return nil
})
if err != nil {
return nil, err
}
}
return h.DB.CurrentState.GetAll(ctx, roomID)
}
type PaginationResponse struct {
Events []*database.Event `json:"events"`
HasMore bool `json:"has_more"`
}
func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) (*PaginationResponse, error) {
evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID)
if err != nil {
return nil, err
} else if len(evts) > 0 {
for _, evt := range evts {
h.ReprocessExistingEvent(ctx, evt)
}
return &PaginationResponse{Events: evts, HasMore: true}, nil
} else {
return h.PaginateServer(ctx, roomID, limit)
}
}
func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) (*PaginationResponse, error) {
ctx, cancel := context.WithCancelCause(ctx)
h.paginationInterrupterLock.Lock()
if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating {
h.paginationInterrupterLock.Unlock()
return nil, ErrPaginationAlreadyInProgress
}
h.paginationInterrupter[roomID] = cancel
h.paginationInterrupterLock.Unlock()
defer func() {
h.paginationInterrupterLock.Lock()
delete(h.paginationInterrupter, roomID)
h.paginationInterrupterLock.Unlock()
}()
room, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to get room from database: %w", err)
} else if room.PrevBatch == database.PrevBatchPaginationComplete {
return &PaginationResponse{Events: []*database.Event{}, HasMore: false}, nil
}
resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit)
if err != nil {
return nil, fmt.Errorf("failed to get messages from server: %w", err)
}
events := make([]*database.Event, len(resp.Chunk))
if resp.End == "" {
resp.End = database.PrevBatchPaginationComplete
}
if resp.End == database.PrevBatchPaginationComplete || len(resp.Chunk) == 0 {
err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
if err != nil {
return nil, fmt.Errorf("failed to set prev_batch: %w", err)
}
return &PaginationResponse{Events: events, HasMore: resp.End != ""}, nil
}
wakeupSessionRequests := false
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
if err = ctx.Err(); err != nil {
return err
}
eventRowIDs := make([]database.EventRowID, len(resp.Chunk))
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
iOffset := 0
for i, evt := range resp.Chunk {
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true)
if err != nil {
return err
} else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil {
return fmt.Errorf("failed to check if event exists in timeline: %w", err)
} else if exists {
zerolog.Ctx(ctx).Warn().
Int64("row_id", int64(dbEvt.RowID)).
Str("event_id", dbEvt.ID.String()).
Msg("Event already exists in timeline, skipping")
iOffset++
continue
}
events[i-iOffset] = dbEvt
eventRowIDs[i-iOffset] = events[i-iOffset].RowID
}
if iOffset >= len(events) {
events = events[:0]
return nil
}
events = events[:len(events)-iOffset]
eventRowIDs = eventRowIDs[:len(eventRowIDs)-iOffset]
wakeupSessionRequests = len(decryptionQueue) > 0
for _, entry := range decryptionQueue {
err = h.DB.SessionRequest.Put(ctx, entry)
if err != nil {
return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
}
}
err = h.DB.Event.FillLastEditRowIDs(ctx, roomID, events)
if err != nil {
return fmt.Errorf("failed to fill last edit row IDs: %w", err)
}
err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
if err != nil {
return fmt.Errorf("failed to set prev_batch: %w", err)
}
var tuples []database.TimelineRowTuple
tuples, err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs)
if err != nil {
return fmt.Errorf("failed to prepend events to timeline: %w", err)
}
for i, evt := range events {
evt.TimelineRowID = tuples[i].Timeline
}
return nil
})
if err == nil && wakeupSessionRequests {
h.WakeupRequestQueue()
}
return &PaginationResponse{Events: events, HasMore: true}, err
}

80
pkg/hicli/pushrules.go Normal file
View file

@ -0,0 +1,80 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
"go.mau.fi/gomuks/pkg/hicli/database"
)
type pushRoom struct {
ctx context.Context
roomID id.RoomID
h *HiClient
ll *mautrix.LazyLoadSummary
}
func (p *pushRoom) GetOwnDisplayname() string {
// TODO implement
return ""
}
func (p *pushRoom) GetMemberCount() int {
if p.ll == nil {
room, err := p.h.DB.Room.Get(p.ctx, p.roomID)
if err != nil {
zerolog.Ctx(p.ctx).Err(err).
Stringer("room_id", p.roomID).
Msg("Failed to get room by ID in push rule evaluator")
} else if room != nil {
p.ll = room.LazyLoadSummary
}
}
if p.ll != nil && p.ll.JoinedMemberCount != nil {
return *p.ll.JoinedMemberCount
}
// TODO query db?
return 0
}
func (p *pushRoom) GetEvent(id id.EventID) *event.Event {
evt, err := p.h.DB.Event.GetByID(p.ctx, id)
if err != nil {
zerolog.Ctx(p.ctx).Err(err).
Stringer("event_id", id).
Msg("Failed to get event by ID in push rule evaluator")
}
return evt.AsRawMautrix()
}
var _ pushrules.EventfulRoom = (*pushRoom)(nil)
func (h *HiClient) evaluatePushRules(ctx context.Context, llSummary *mautrix.LazyLoadSummary, baseType database.UnreadType, evt *event.Event) database.UnreadType {
should := h.PushRules.Load().GetMatchingRule(&pushRoom{
ctx: ctx,
roomID: evt.RoomID,
h: h,
ll: llSummary,
}, evt).GetActions().Should()
if should.Notify {
baseType |= database.UnreadTypeNotify
}
if should.Highlight {
baseType |= database.UnreadTypeHighlight
}
if should.PlaySound {
baseType |= database.UnreadTypeSound
}
return baseType
}

287
pkg/hicli/send.go Normal file
View file

@ -0,0 +1,287 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/yuin/goldmark"
"go.mau.fi/util/jsontime"
"go.mau.fi/util/ptr"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli/database"
"go.mau.fi/gomuks/pkg/rainbow"
)
var (
rainbowWithHTML = goldmark.New(format.Extensions, format.HTMLOptions, goldmark.WithExtensions(rainbow.Extension))
)
func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, mediaPath string, replyTo id.EventID, mentions *event.Mentions) (*database.Event, error) {
var content event.MessageEventContent
if strings.HasPrefix(text, "/rainbow ") {
text = strings.TrimPrefix(text, "/rainbow ")
content = format.RenderMarkdownCustom(text, rainbowWithHTML)
content.FormattedBody = rainbow.ApplyColor(content.FormattedBody)
} else if strings.HasPrefix(text, "/plain ") {
text = strings.TrimPrefix(text, "/plain ")
content = format.RenderMarkdown(text, false, false)
} else if strings.HasPrefix(text, "/html ") {
text = strings.TrimPrefix(text, "/html ")
content = format.RenderMarkdown(text, false, true)
} else {
content = format.RenderMarkdown(text, true, false)
}
if mentions != nil {
content.Mentions.Room = mentions.Room
for _, userID := range mentions.UserIDs {
if userID != h.Account.UserID {
content.Mentions.Add(userID)
}
}
}
if replyTo != "" {
content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(replyTo)
}
return h.Send(ctx, roomID, event.EventMessage, &content)
}
func (h *HiClient) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType) error {
content := &mautrix.ReqSetReadMarkers{
FullyRead: eventID,
}
if receiptType == event.ReceiptTypeRead {
content.Read = eventID
} else if receiptType == event.ReceiptTypeReadPrivate {
content.ReadPrivate = eventID
} else {
return fmt.Errorf("invalid receipt type: %v", receiptType)
}
err := h.Client.SetReadMarkers(ctx, roomID, content)
if err != nil {
return fmt.Errorf("failed to mark event as read: %w", err)
}
return nil
}
func (h *HiClient) SetTyping(ctx context.Context, roomID id.RoomID, timeout time.Duration) error {
_, err := h.Client.UserTyping(ctx, roomID, timeout > 0, timeout)
return err
}
func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*database.Event, error) {
roomMeta, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to get room metadata: %w", err)
} else if roomMeta == nil {
return nil, fmt.Errorf("unknown room")
}
var decryptedType event.Type
var decryptedContent json.RawMessage
var megolmSessionID id.SessionID
if roomMeta.EncryptionEvent != nil && evtType != event.EventReaction {
decryptedType = evtType
decryptedContent, err = json.Marshal(content)
if err != nil {
return nil, fmt.Errorf("failed to marshal event content: %w", err)
}
encryptedContent, err := h.Encrypt(ctx, roomMeta, evtType, content)
if err != nil {
return nil, fmt.Errorf("failed to encrypt event: %w", err)
}
megolmSessionID = encryptedContent.SessionID
content = encryptedContent
evtType = event.EventEncrypted
}
mainContent, err := json.Marshal(content)
if err != nil {
return nil, fmt.Errorf("failed to marshal event content: %w", err)
}
txnID := "hicli-" + h.Client.TxnID()
relatesTo, relationType := database.GetRelatesToFromBytes(mainContent)
dbEvt := &database.Event{
RoomID: roomID,
ID: id.EventID(fmt.Sprintf("~%s", txnID)),
Sender: h.Account.UserID,
Type: evtType.Type,
Timestamp: jsontime.UnixMilliNow(),
Content: mainContent,
Decrypted: decryptedContent,
DecryptedType: decryptedType.Type,
Unsigned: []byte("{}"),
TransactionID: txnID,
RelatesTo: relatesTo,
RelationType: relationType,
MegolmSessionID: megolmSessionID,
DecryptionError: "",
SendError: "not sent",
Reactions: map[string]int{},
LastEditRowID: ptr.Ptr(database.EventRowID(0)),
}
_, err = h.DB.Event.Insert(ctx, dbEvt)
if err != nil {
return nil, fmt.Errorf("failed to insert event into database: %w", err)
}
ctx = context.WithoutCancel(ctx)
go func() {
err := h.SetTyping(ctx, roomID, 0)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to stop typing while sending message")
}
}()
go func() {
var err error
defer func() {
h.EventHandler(&SendComplete{
Event: dbEvt,
Error: err,
})
}()
var resp *mautrix.RespSendEvent
resp, err = h.Client.SendMessageEvent(ctx, roomID, evtType, content, mautrix.ReqSendEvent{
Timestamp: dbEvt.Timestamp.UnixMilli(),
TransactionID: txnID,
DontEncrypt: true,
})
if err != nil {
dbEvt.SendError = err.Error()
err = fmt.Errorf("failed to send event: %w", err)
err2 := h.DB.Event.UpdateSendError(ctx, dbEvt.RowID, dbEvt.SendError)
if err2 != nil {
zerolog.Ctx(ctx).Err(err2).AnErr("send_error", err).
Msg("Failed to update send error in database after sending failed")
}
return
}
dbEvt.ID = resp.EventID
err = h.DB.Event.UpdateID(ctx, dbEvt.RowID, dbEvt.ID)
if err != nil {
err = fmt.Errorf("failed to update event ID in database: %w", err)
}
}()
return dbEvt, nil
}
func (h *HiClient) Encrypt(ctx context.Context, room *database.Room, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
h.encryptLock.Lock()
defer h.encryptLock.Unlock()
encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content)
if errors.Is(err, crypto.SessionExpired) || errors.Is(err, crypto.NoGroupSession) || errors.Is(err, crypto.SessionNotShared) {
if err = h.shareGroupSession(ctx, room); err != nil {
err = fmt.Errorf("failed to share group session: %w", err)
} else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content); err != nil {
err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
}
}
return
}
func (h *HiClient) EnsureGroupSessionShared(ctx context.Context, roomID id.RoomID) error {
h.encryptLock.Lock()
defer h.encryptLock.Unlock()
if session, err := h.CryptoStore.GetOutboundGroupSession(ctx, roomID); err != nil {
return fmt.Errorf("failed to get previous outbound group session: %w", err)
} else if session != nil && session.Shared && !session.Expired() {
return nil
} else if roomMeta, err := h.DB.Room.Get(ctx, roomID); err != nil {
return fmt.Errorf("failed to get room metadata: %w", err)
} else if roomMeta == nil {
return fmt.Errorf("unknown room")
} else {
return h.shareGroupSession(ctx, roomMeta)
}
}
func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error {
if room.HasMemberList {
return nil
}
resp, err := h.Client.Members(ctx, room.ID)
if err != nil {
return fmt.Errorf("failed to get room member list: %w", err)
}
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
entries := make([]*database.CurrentStateEntry, len(resp.Chunk))
for i, evt := range resp.Chunk {
dbEvt, err := h.processEvent(ctx, evt, nil, nil, true)
if err != nil {
return err
}
entries[i] = &database.CurrentStateEntry{
EventType: evt.Type,
StateKey: *evt.StateKey,
EventRowID: dbEvt.RowID,
Membership: event.Membership(evt.Content.Raw["membership"].(string)),
}
}
err := h.DB.CurrentState.AddMany(ctx, room.ID, false, entries)
if err != nil {
return err
}
return h.DB.Room.Upsert(ctx, &database.Room{
ID: room.ID,
HasMemberList: true,
})
})
if err != nil {
return fmt.Errorf("failed to process room member list: %w", err)
}
return nil
}
func (h *HiClient) shareGroupSession(ctx context.Context, room *database.Room) error {
err := h.loadMembers(ctx, room)
if err != nil {
return err
}
shareToInvited := h.shouldShareKeysToInvitedUsers(ctx, room.ID)
var users []id.UserID
if shareToInvited {
users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, room.ID)
} else {
users, err = h.ClientStore.GetRoomJoinedMembers(ctx, room.ID)
}
if err != nil {
return fmt.Errorf("failed to get room member list: %w", err)
} else if err = h.Crypto.ShareGroupSession(ctx, room.ID, users); err != nil {
return fmt.Errorf("failed to share group session: %w", err)
}
return nil
}
func (h *HiClient) shouldShareKeysToInvitedUsers(ctx context.Context, roomID id.RoomID) bool {
historyVisibility, err := h.DB.CurrentState.Get(ctx, roomID, event.StateHistoryVisibility, "")
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get history visibility event")
return false
}
mautrixEvt := historyVisibility.AsRawMautrix()
err = mautrixEvt.Content.ParseRaw(mautrixEvt.Type)
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
zerolog.Ctx(ctx).Err(err).Msg("Failed to parse history visibility event")
return false
}
hv, ok := mautrixEvt.Content.Parsed.(*event.HistoryVisibilityEventContent)
if !ok {
zerolog.Ctx(ctx).Warn().Msg("Unexpected parsed content type for history visibility event")
return false
}
return hv.HistoryVisibility == event.HistoryVisibilityInvited ||
hv.HistoryVisibility == event.HistoryVisibilityShared ||
hv.HistoryVisibility == event.HistoryVisibilityWorldReadable
}

850
pkg/hicli/sync.go Normal file
View file

@ -0,0 +1,850 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"errors"
"fmt"
"slices"
"strings"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/exzerolog"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
"go.mau.fi/gomuks/pkg/hicli/database"
)
type syncContext struct {
shouldWakeupRequestQueue bool
evt *SyncComplete
}
func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
log := zerolog.Ctx(ctx)
postponedToDevices := resp.ToDevice.Events[:0]
for _, evt := range resp.ToDevice.Events {
evt.Type.Class = event.ToDeviceEventType
err := evt.Content.ParseRaw(evt.Type)
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
log.Warn().Err(err).
Stringer("event_type", &evt.Type).
Stringer("sender", evt.Sender).
Msg("Failed to parse to-device event, skipping")
continue
}
switch content := evt.Content.Parsed.(type) {
case *event.EncryptedEventContent:
h.Crypto.HandleEncryptedEvent(ctx, evt)
case *event.RoomKeyWithheldEventContent:
h.Crypto.HandleRoomKeyWithheld(ctx, content)
default:
postponedToDevices = append(postponedToDevices, evt)
}
}
resp.ToDevice.Events = postponedToDevices
return nil
}
func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
go h.asyncPostProcessSyncResponse(ctx, resp, since)
syncCtx := ctx.Value(syncContextKey).(*syncContext)
if syncCtx.shouldWakeupRequestQueue {
h.WakeupRequestQueue()
}
h.firstSyncReceived = true
if !syncCtx.evt.IsEmpty() {
h.EventHandler(syncCtx.evt)
}
}
func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
for _, evt := range resp.ToDevice.Events {
switch content := evt.Content.Parsed.(type) {
case *event.SecretRequestEventContent:
h.Crypto.HandleSecretRequest(ctx, evt.Sender, content)
case *event.RoomKeyRequestEventContent:
h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content)
}
}
}
func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
if len(resp.DeviceLists.Changed) > 0 {
zerolog.Ctx(ctx).Debug().
Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)).
Msg("Marking changed device lists for tracked users as outdated")
err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed)
if err != nil {
return fmt.Errorf("failed to mark changed device lists as outdated: %w", err)
}
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
}
for _, evt := range resp.AccountData.Events {
evt.Type.Class = event.AccountDataEventType
err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw)
if err != nil {
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
}
if evt.Type == event.AccountDataPushRules {
err = evt.Content.ParseRaw(evt.Type)
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to parse push rules in sync")
} else if pushRules, ok := evt.Content.Parsed.(*pushrules.EventContent); ok {
h.PushRules.Store(pushRules.Ruleset)
zerolog.Ctx(ctx).Debug().Msg("Updated push rules from sync")
}
}
}
for roomID, room := range resp.Rooms.Join {
err := h.processSyncJoinedRoom(ctx, roomID, room)
if err != nil {
return fmt.Errorf("failed to process joined room %s: %w", roomID, err)
}
}
for roomID, room := range resp.Rooms.Leave {
err := h.processSyncLeftRoom(ctx, roomID, room)
if err != nil {
return fmt.Errorf("failed to process left room %s: %w", roomID, err)
}
}
h.Account.NextBatch = resp.NextBatch
err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch)
if err != nil {
return fmt.Errorf("failed to save next_batch: %w", err)
}
return nil
}
func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) {
receiptList := make([]*database.Receipt, 0)
var newOwnReceipts []id.EventID
for eventID, receipts := range *content {
for receiptType, users := range receipts {
for userID, receiptInfo := range users {
if userID == h.Account.UserID {
newOwnReceipts = append(newOwnReceipts, eventID)
}
receiptList = append(receiptList, &database.Receipt{
UserID: userID,
ReceiptType: receiptType,
ThreadID: receiptInfo.ThreadID,
EventID: eventID,
Timestamp: jsontime.UM(receiptInfo.Timestamp),
})
}
}
}
return receiptList, newOwnReceipts
}
type receiptsToSave struct {
roomID id.RoomID
receipts []*database.Receipt
}
func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to get room data: %w", err)
} else if existingRoomData == nil {
err = h.DB.Room.CreateRow(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to ensure room row exists: %w", err)
}
existingRoomData = &database.Room{ID: roomID, SortingTimestamp: jsontime.UnixMilliNow()}
}
for _, evt := range room.AccountData.Events {
evt.Type.Class = event.AccountDataEventType
evt.RoomID = roomID
err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw)
if err != nil {
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
}
}
var receipts []receiptsToSave
var newOwnReceipts []id.EventID
for _, evt := range room.Ephemeral.Events {
evt.Type.Class = event.EphemeralEventType
err = evt.Content.ParseRaw(evt.Type)
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content")
continue
}
switch evt.Type {
case event.EphemeralEventReceipt:
var receiptsList []*database.Receipt
receiptsList, newOwnReceipts = h.receiptsToList(evt.Content.AsReceipt())
receipts = append(receipts, receiptsToSave{roomID, receiptsList})
case event.EphemeralEventTyping:
go h.EventHandler(&Typing{
RoomID: roomID,
TypingEventContent: *evt.Content.AsTyping(),
})
}
}
err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, newOwnReceipts, room.UnreadNotifications)
if err != nil {
return err
}
for _, rs := range receipts {
err = h.DB.Receipt.PutMany(ctx, rs.roomID, rs.receipts...)
if err != nil {
return fmt.Errorf("failed to save receipts: %w", err)
}
}
return nil
}
func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error {
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to get room data: %w", err)
} else if existingRoomData == nil {
return nil
}
// TODO delete room instead of processing?
return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil)
}
func isDecryptionErrorRetryable(err error) bool {
return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld)
}
func removeReplyFallback(evt *event.Event) []byte {
if evt.Type != event.EventMessage && evt.Type != event.EventSticker {
return nil
}
_ = evt.Content.ParseRaw(evt.Type)
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if ok && content.RelatesTo.GetReplyTo() != "" {
prevFormattedBody := content.FormattedBody
content.RemoveReplyFallback()
if content.FormattedBody != prevFormattedBody {
bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody)
bytes, err2 := sjson.SetBytes(bytes, "body", content.Body)
if err == nil && err2 == nil {
return bytes
}
}
}
return nil
}
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) {
err := evt.Content.ParseRaw(evt.Type)
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
return nil, nil, "", err
}
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
if err != nil {
return nil, nil, "", err
}
withoutFallback := removeReplyFallback(decrypted)
if withoutFallback != nil {
return decrypted, withoutFallback, decrypted.Type.Type, nil
}
return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil
}
func (h *HiClient) addMediaCache(
ctx context.Context,
eventRowID database.EventRowID,
uri id.ContentURIString,
file *event.EncryptedFileInfo,
info *event.FileInfo,
fileName string,
) {
parsedMXC := uri.ParseOrIgnore()
if !parsedMXC.IsValid() {
return
}
cm := &database.CachedMedia{
MXC: parsedMXC,
EventRowID: eventRowID,
FileName: fileName,
}
if file != nil {
cm.EncFile = &file.EncryptedFile
}
if info != nil {
cm.MimeType = info.MimeType
}
err := h.DB.CachedMedia.Put(ctx, cm)
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Stringer("mxc", parsedMXC).
Int64("event_rowid", int64(eventRowID)).
Msg("Failed to add cached media entry")
}
}
func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) {
switch evt.Type {
case event.EventMessage, event.EventSticker:
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok {
return
}
if content.File != nil {
h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName())
} else if content.URL != "" {
h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName())
}
if content.GetInfo().ThumbnailFile != nil {
h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "")
} else if content.GetInfo().ThumbnailURL != "" {
h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "")
}
case event.StateRoomAvatar:
_ = evt.Content.ParseRaw(evt.Type)
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
if !ok {
return
}
h.addMediaCache(ctx, rowID, content.URL, nil, nil, "")
case event.StateMember:
_ = evt.Content.ParseRaw(evt.Type)
content, ok := evt.Content.Parsed.(*event.MemberEventContent)
if !ok {
return
}
h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "")
}
}
func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent {
if evt.Type != event.EventMessage && evt.Type != event.EventSticker {
return nil
}
_ = evt.Content.ParseRaw(evt.Type)
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok {
return nil
}
if dbEvt.RelationType == event.RelReplace && content.NewContent != nil {
content = content.NewContent
}
if content != nil {
var sanitizedHTML string
if content.Format == event.FormatHTML {
sanitizedHTML, _ = sanitizeAndLinkifyHTML(content.FormattedBody)
} else {
var builder strings.Builder
linkifyAndWriteBytes(&builder, []byte(content.Body))
sanitizedHTML = builder.String()
}
return &database.LocalContent{SanitizedHTML: sanitizedHTML, HTMLVersion: CurrentHTMLSanitizerVersion}
}
return nil
}
const CurrentHTMLSanitizerVersion = 1
func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Event) {
if evt.Type != event.EventMessage.Type || evt.LocalContent == nil || evt.LocalContent.HTMLVersion >= CurrentHTMLSanitizerVersion {
return
}
evt.LocalContent = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix())
err := h.DB.Event.UpdateLocalContent(ctx, evt)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("event_id", evt.ID).
Msg("Failed to update local content")
}
}
func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) {
if dbEvt.RowID != 0 {
h.cacheMedia(ctx, evt, dbEvt.RowID)
}
if evt.Sender != h.Account.UserID {
dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt)
}
dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt)
}
func (h *HiClient) processEvent(
ctx context.Context,
evt *event.Event,
llSummary *mautrix.LazyLoadSummary,
decryptionQueue map[id.SessionID]*database.SessionRequest,
checkDB bool,
) (*database.Event, error) {
if checkDB {
dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID)
if err != nil {
return nil, fmt.Errorf("failed to check if event %s exists: %w", evt.ID, err)
} else if dbEvt != nil {
return dbEvt, nil
}
}
dbEvt := database.MautrixToEvent(evt)
contentWithoutFallback := removeReplyFallback(evt)
if contentWithoutFallback != nil {
dbEvt.Content = contentWithoutFallback
}
var decryptionErr error
var decryptedMautrixEvt *event.Event
if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
if decryptionErr != nil {
dbEvt.DecryptionError = decryptionErr.Error()
}
} else if evt.Type == event.EventRedaction {
if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() {
var err error
evt.Content.VeryRaw, err = sjson.SetBytes(evt.Content.VeryRaw, "redacts", evt.Redacts)
if err != nil {
return dbEvt, fmt.Errorf("failed to set redacts field: %w", err)
}
} else if evt.Redacts == "" {
evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str)
}
}
if decryptedMautrixEvt != nil {
h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt)
} else {
h.postDecryptProcess(ctx, llSummary, dbEvt, evt)
}
_, err := h.DB.Event.Upsert(ctx, dbEvt)
if err != nil {
return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
}
if decryptedMautrixEvt != nil {
h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID)
} else {
h.cacheMedia(ctx, evt, dbEvt.RowID)
}
if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
req, ok := decryptionQueue[dbEvt.MegolmSessionID]
if !ok {
req = &database.SessionRequest{
RoomID: evt.RoomID,
SessionID: dbEvt.MegolmSessionID,
Sender: evt.Sender,
}
}
minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext)
req.MinIndex = min(uint32(minIndex), req.MinIndex)
if decryptionQueue != nil {
decryptionQueue[dbEvt.MegolmSessionID] = req
} else {
err = h.DB.SessionRequest.Put(ctx, req)
if err != nil {
zerolog.Ctx(ctx).Err(err).
Stringer("session_id", dbEvt.MegolmSessionID).
Msg("Failed to save session request")
} else {
h.WakeupRequestQueue()
}
}
}
return dbEvt, err
}
func (h *HiClient) processStateAndTimeline(
ctx context.Context,
room *database.Room,
state *mautrix.SyncEventsList,
timeline *mautrix.SyncTimeline,
summary *mautrix.LazyLoadSummary,
newOwnReceipts []id.EventID,
serverNotificationCounts *mautrix.UnreadNotificationCounts,
) error {
updatedRoom := &database.Room{
ID: room.ID,
SortingTimestamp: room.SortingTimestamp,
NameQuality: room.NameQuality,
UnreadHighlights: room.UnreadHighlights,
UnreadNotifications: room.UnreadNotifications,
UnreadMessages: room.UnreadMessages,
}
if serverNotificationCounts != nil {
updatedRoom.UnreadHighlights = serverNotificationCounts.HighlightCount
updatedRoom.UnreadNotifications = serverNotificationCounts.NotificationCount
}
heroesChanged := false
if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil {
summary = room.LazyLoadSummary
} else if room.LazyLoadSummary == nil ||
!slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) ||
!intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) ||
!intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) {
updatedRoom.LazyLoadSummary = summary
heroesChanged = true
}
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events))
newNotifications := make([]SyncNotification, 0)
recalculatePreviewEvent := false
addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) {
if rowID != 0 {
dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID)
} else {
dbEvt, err = h.DB.Event.GetByID(ctx, evtID)
}
if err != nil {
return nil, fmt.Errorf("failed to get redaction target: %w", err)
} else if dbEvt == nil {
return nil, nil
}
allNewEvents = append(allNewEvents, dbEvt)
return dbEvt, nil
}
processRedaction := func(evt *event.Event) error {
dbEvt, err := addOldEvent(0, evt.Redacts)
if err != nil {
return fmt.Errorf("failed to get redaction target: %w", err)
}
if dbEvt == nil {
return nil
}
if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
_, err = addOldEvent(0, dbEvt.RelatesTo)
if err != nil {
return fmt.Errorf("failed to get relation target of redaction target: %w", err)
}
}
if updatedRoom.PreviewEventRowID == dbEvt.RowID {
updatedRoom.PreviewEventRowID = 0
recalculatePreviewEvent = true
}
return nil
}
processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) {
evt.RoomID = room.ID
dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, false)
if err != nil {
return -1, err
}
if isUnread && dbEvt.UnreadType.Is(database.UnreadTypeNotify) {
newNotifications = append(newNotifications, SyncNotification{
RowID: dbEvt.RowID,
Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound),
})
}
if isTimeline {
if dbEvt.CanUseForPreview() {
updatedRoom.PreviewEventRowID = dbEvt.RowID
recalculatePreviewEvent = false
}
updatedRoom.BumpSortingTimestamp(dbEvt)
}
if evt.StateKey != nil {
var membership event.Membership
if evt.Type == event.StateMember {
membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str)
if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) {
heroesChanged = true
}
} else if evt.Type == event.StateElementFunctionalMembers {
heroesChanged = true
}
err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership)
if err != nil {
return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err)
}
processImportantEvent(ctx, evt, room, updatedRoom)
}
allNewEvents = append(allNewEvents, dbEvt)
if evt.Type == event.EventRedaction && evt.Redacts != "" {
err = processRedaction(evt)
if err != nil {
return -1, fmt.Errorf("failed to process redaction: %w", err)
}
} else if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
_, err = addOldEvent(0, dbEvt.RelatesTo)
if err != nil {
return -1, fmt.Errorf("failed to get relation target of event: %w", err)
}
}
return dbEvt.RowID, nil
}
changedState := make(map[event.Type]map[string]database.EventRowID)
setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) {
if _, ok := changedState[evtType]; !ok {
changedState[evtType] = make(map[string]database.EventRowID)
}
changedState[evtType][stateKey] = rowID
}
for _, evt := range state.Events {
evt.Type.Class = event.StateEventType
rowID, err := processNewEvent(evt, false, false)
if err != nil {
return err
}
setNewState(evt.Type, *evt.StateKey, rowID)
}
var timelineRowTuples []database.TimelineRowTuple
var err error
if len(timeline.Events) > 0 {
timelineIDs := make([]database.EventRowID, len(timeline.Events))
readUpToIndex := -1
for i := len(timeline.Events) - 1; i >= 0; i-- {
if slices.Contains(newOwnReceipts, timeline.Events[i].ID) {
readUpToIndex = i
break
}
}
for i, evt := range timeline.Events {
if evt.StateKey != nil {
evt.Type.Class = event.StateEventType
} else {
evt.Type.Class = event.MessageEventType
}
timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex)
if err != nil {
return err
}
if evt.StateKey != nil {
setNewState(evt.Type, *evt.StateKey, timelineIDs[i])
}
}
for _, entry := range decryptionQueue {
err = h.DB.SessionRequest.Put(ctx, entry)
if err != nil {
return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
}
}
if len(decryptionQueue) > 0 {
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
}
if timeline.Limited {
err = h.DB.Timeline.Clear(ctx, room.ID)
if err != nil {
return fmt.Errorf("failed to clear old timeline: %w", err)
}
updatedRoom.PrevBatch = timeline.PrevBatch
h.paginationInterrupterLock.Lock()
if interrupt, ok := h.paginationInterrupter[room.ID]; ok {
interrupt(ErrTimelineReset)
}
h.paginationInterrupterLock.Unlock()
}
timelineRowTuples, err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs)
if err != nil {
return fmt.Errorf("failed to append timeline: %w", err)
}
} else {
timelineRowTuples = make([]database.TimelineRowTuple, 0)
}
if recalculatePreviewEvent && updatedRoom.PreviewEventRowID == 0 {
updatedRoom.PreviewEventRowID, err = h.DB.Room.RecalculatePreview(ctx, room.ID)
if err != nil {
return fmt.Errorf("failed to recalculate preview event: %w", err)
}
_, err = addOldEvent(updatedRoom.PreviewEventRowID, "")
if err != nil {
return fmt.Errorf("failed to get preview event: %w", err)
}
}
// Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset
if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil {
name, dmAvatarURL, err := h.calculateRoomParticipantName(ctx, room.ID, summary)
if err != nil {
return fmt.Errorf("failed to calculate room name: %w", err)
}
updatedRoom.Name = &name
updatedRoom.NameQuality = database.NameQualityParticipants
if !dmAvatarURL.IsEmpty() && !room.ExplicitAvatar {
updatedRoom.Avatar = &dmAvatarURL
}
}
if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) {
updatedRoom.PrevBatch = timeline.PrevBatch
}
roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
if roomChanged {
err = h.DB.Room.Upsert(ctx, updatedRoom)
if err != nil {
return fmt.Errorf("failed to save room data: %w", err)
}
}
if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 {
ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
Meta: room,
Timeline: timelineRowTuples,
State: changedState,
Reset: timeline.Limited,
Events: allNewEvents,
Notifications: newNotifications,
}
}
return nil
}
func joinMemberNames(names []string, totalCount int) string {
if len(names) == 1 {
return names[0]
} else if len(names) < 5 || (len(names) == 5 && totalCount <= 6) {
return strings.Join(names[:len(names)-1], ", ") + " and " + names[len(names)-1]
} else {
return fmt.Sprintf("%s and %d others", strings.Join(names[:4], ", "), totalCount-5)
}
}
func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, id.ContentURI, error) {
var primaryAvatarURL id.ContentURI
if summary == nil || len(summary.Heroes) == 0 {
return "Empty room", primaryAvatarURL, nil
}
var functionalMembers []id.UserID
functionalMembersEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "")
if err != nil {
return "", primaryAvatarURL, fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err)
} else if functionalMembersEvt != nil {
mautrixEvt := functionalMembersEvt.AsRawMautrix()
_ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type)
content, ok := mautrixEvt.Content.Parsed.(*event.ElementFunctionalMembersContent)
if ok {
functionalMembers = content.ServiceMembers
}
}
var members, leftMembers []string
var memberCount int
if summary.JoinedMemberCount != nil && *summary.JoinedMemberCount > 0 {
memberCount = *summary.JoinedMemberCount
} else if summary.InvitedMemberCount != nil {
memberCount = *summary.InvitedMemberCount
}
for _, hero := range summary.Heroes {
if slices.Contains(functionalMembers, hero) {
memberCount--
continue
} else if len(members) >= 5 {
break
}
heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String())
if err != nil {
return "", primaryAvatarURL, fmt.Errorf("failed to get %s's member event: %w", hero, err)
} else if heroEvt == nil {
leftMembers = append(leftMembers, hero.String())
continue
}
membership := gjson.GetBytes(heroEvt.Content, "membership").Str
name := gjson.GetBytes(heroEvt.Content, "displayname").Str
if name == "" {
name = hero.String()
}
avatarURL := gjson.GetBytes(heroEvt.Content, "avatar_url").Str
if avatarURL != "" {
primaryAvatarURL = id.ContentURIString(avatarURL).ParseOrIgnore()
}
if membership == "join" || membership == "invite" {
members = append(members, name)
} else {
leftMembers = append(leftMembers, name)
}
}
if len(members)+len(leftMembers) > 1 || !primaryAvatarURL.IsValid() {
primaryAvatarURL = id.ContentURI{}
}
if len(members) > 0 {
return joinMemberNames(members, memberCount), primaryAvatarURL, nil
} else if len(leftMembers) > 0 {
return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), primaryAvatarURL, nil
} else {
return "Empty room", primaryAvatarURL, nil
}
}
func intPtrEqual(a, b *int) bool {
if a == nil || b == nil {
return a == b
}
return *a == *b
}
func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData, updatedRoom *database.Room) (roomDataChanged bool) {
if evt.StateKey == nil {
return
}
switch evt.Type {
case event.StateCreate, event.StateRoomName, event.StateCanonicalAlias, event.StateRoomAvatar, event.StateTopic, event.StateEncryption:
if *evt.StateKey != "" {
return
}
default:
return
}
err := evt.Content.ParseRaw(evt.Type)
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
zerolog.Ctx(ctx).Warn().Err(err).
Stringer("event_type", &evt.Type).
Stringer("event_id", evt.ID).
Msg("Failed to parse state event, skipping")
return
}
switch evt.Type {
case event.StateCreate:
updatedRoom.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent)
case event.StateEncryption:
newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent)
if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm {
updatedRoom.EncryptionEvent = newEncryption
}
case event.StateRoomName:
content, ok := evt.Content.Parsed.(*event.RoomNameEventContent)
if ok {
updatedRoom.Name = &content.Name
updatedRoom.NameQuality = database.NameQualityExplicit
if content.Name == "" {
if updatedRoom.CanonicalAlias != nil && *updatedRoom.CanonicalAlias != "" {
updatedRoom.Name = (*string)(updatedRoom.CanonicalAlias)
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
} else if existingRoomData.CanonicalAlias != nil && *existingRoomData.CanonicalAlias != "" {
updatedRoom.Name = (*string)(existingRoomData.CanonicalAlias)
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
} else {
updatedRoom.NameQuality = database.NameQualityNil
}
}
}
case event.StateCanonicalAlias:
content, ok := evt.Content.Parsed.(*event.CanonicalAliasEventContent)
if ok {
updatedRoom.CanonicalAlias = &content.Alias
if updatedRoom.NameQuality <= database.NameQualityCanonicalAlias {
updatedRoom.Name = (*string)(&content.Alias)
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
if content.Alias == "" {
updatedRoom.NameQuality = database.NameQualityNil
}
}
}
case event.StateRoomAvatar:
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
if ok {
url, _ := content.URL.Parse()
updatedRoom.Avatar = &url
updatedRoom.ExplicitAvatar = true
}
case event.StateTopic:
content, ok := evt.Content.Parsed.(*event.TopicEventContent)
if ok {
updatedRoom.Topic = &content.Topic
}
}
return
}

96
pkg/hicli/syncwrap.go Normal file
View file

@ -0,0 +1,96 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"fmt"
"time"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
)
type hiSyncer HiClient
var _ mautrix.Syncer = (*hiSyncer)(nil)
type contextKey int
const (
syncContextKey contextKey = iota
)
func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
c := (*HiClient)(h)
ctx = context.WithValue(ctx, syncContextKey, &syncContext{evt: &SyncComplete{Rooms: make(map[id.RoomID]*SyncRoom, len(resp.Rooms.Join))}})
err := c.preProcessSyncResponse(ctx, resp, since)
if err != nil {
return err
}
err = c.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
return c.processSyncResponse(ctx, resp, since)
})
if err != nil {
return err
}
c.postProcessSyncResponse(ctx, resp, since)
return nil
}
func (h *hiSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
(*HiClient)(h).Log.Err(err).Msg("Sync failed, retrying in 1 second")
return 1 * time.Second, nil
}
func (h *hiSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
if !h.Verified {
return &mautrix.Filter{
Presence: mautrix.FilterPart{
NotRooms: []id.RoomID{"*"},
},
Room: mautrix.RoomFilter{
NotRooms: []id.RoomID{"*"},
},
}
}
return &mautrix.Filter{
Presence: mautrix.FilterPart{
NotRooms: []id.RoomID{"*"},
},
Room: mautrix.RoomFilter{
State: mautrix.FilterPart{
LazyLoadMembers: true,
},
Timeline: mautrix.FilterPart{
Limit: 100,
LazyLoadMembers: true,
},
},
}
}
type hiStore HiClient
var _ mautrix.SyncStore = (*hiStore)(nil)
// Filter ID save and load are intentionally no-ops: we want to recreate filters when restarting syncing
func (h *hiStore) SaveFilterID(_ context.Context, _ id.UserID, _ string) error { return nil }
func (h *hiStore) LoadFilterID(_ context.Context, _ id.UserID) (string, error) { return "", nil }
func (h *hiStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
// This is intentionally a no-op: we don't want to save the next batch before processing the sync
return nil
}
func (h *hiStore) LoadNextBatch(_ context.Context, userID id.UserID) (string, error) {
if h.Account.UserID != userID {
return "", fmt.Errorf("mismatching user ID")
}
return h.Account.NextBatch, nil
}

161
pkg/hicli/verify.go Normal file
View file

@ -0,0 +1,161 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package hicli
import (
"context"
"encoding/base64"
"fmt"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/crypto/ssss"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
func (h *HiClient) checkIsCurrentDeviceVerified(ctx context.Context) (bool, error) {
keys := h.Crypto.GetOwnCrossSigningPublicKeys(ctx)
if keys == nil {
return false, fmt.Errorf("own cross-signing keys not found")
}
isVerified, err := h.Crypto.CryptoStore.IsKeySignedBy(ctx, h.Account.UserID, h.Crypto.GetAccount().SigningKey(), h.Account.UserID, keys.SelfSigningKey)
if err != nil {
return false, fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err)
}
return isVerified, nil
}
func (h *HiClient) fetchKeyBackupKey(ctx context.Context, ssssKey *ssss.Key) error {
latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx)
if err != nil {
return fmt.Errorf("failed to get key backup latest version: %w", err)
}
h.KeyBackupVersion = latestVersion.Version
data, err := h.Crypto.SSSS.GetDecryptedAccountData(ctx, event.AccountDataMegolmBackupKey, ssssKey)
if err != nil {
return fmt.Errorf("failed to get megolm backup key from SSSS: %w", err)
}
key, err := backup.MegolmBackupKeyFromBytes(data)
if err != nil {
return fmt.Errorf("failed to parse megolm backup key: %w", err)
}
err = h.CryptoStore.PutSecret(ctx, id.SecretMegolmBackupV1, base64.StdEncoding.EncodeToString(key.Bytes()))
if err != nil {
return fmt.Errorf("failed to store megolm backup key: %w", err)
}
h.KeyBackupKey = key
return nil
}
func (h *HiClient) getAndDecodeSecret(ctx context.Context, secret id.Secret) ([]byte, error) {
secretData, err := h.CryptoStore.GetSecret(ctx, secret)
if err != nil {
return nil, fmt.Errorf("failed to get secret %s: %w", secret, err)
}
data, err := base64.StdEncoding.DecodeString(secretData)
if err != nil {
return nil, fmt.Errorf("failed to decode secret %s: %w", secret, err)
}
return data, nil
}
func (h *HiClient) loadPrivateKeys(ctx context.Context) error {
zerolog.Ctx(ctx).Debug().Msg("Loading cross-signing private keys")
masterKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSMaster)
if err != nil {
return fmt.Errorf("failed to get master key: %w", err)
}
selfSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSSelfSigning)
if err != nil {
return fmt.Errorf("failed to get self-signing key: %w", err)
}
userSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSUserSigning)
if err != nil {
return fmt.Errorf("failed to get user signing key: %w", err)
}
err = h.Crypto.ImportCrossSigningKeys(crypto.CrossSigningSeeds{
MasterKey: masterKeySeed,
SelfSigningKey: selfSigningKeySeed,
UserSigningKey: userSigningKeySeed,
})
if err != nil {
return fmt.Errorf("failed to import cross-signing private keys: %w", err)
}
zerolog.Ctx(ctx).Debug().Msg("Loading key backup key")
keyBackupKey, err := h.getAndDecodeSecret(ctx, id.SecretMegolmBackupV1)
if err != nil {
return fmt.Errorf("failed to get megolm backup key: %w", err)
}
h.KeyBackupKey, err = backup.MegolmBackupKeyFromBytes(keyBackupKey)
if err != nil {
return fmt.Errorf("failed to parse megolm backup key: %w", err)
}
zerolog.Ctx(ctx).Debug().Msg("Fetching key backup version")
latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx)
if err != nil {
return fmt.Errorf("failed to get key backup latest version: %w", err)
}
h.KeyBackupVersion = latestVersion.Version
zerolog.Ctx(ctx).Debug().Msg("Secrets loaded")
return nil
}
func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error {
keys := h.Crypto.CrossSigningKeys
err := h.CryptoStore.PutSecret(ctx, id.SecretXSMaster, base64.StdEncoding.EncodeToString(keys.MasterKey.Seed()))
if err != nil {
return err
}
err = h.CryptoStore.PutSecret(ctx, id.SecretXSSelfSigning, base64.StdEncoding.EncodeToString(keys.SelfSigningKey.Seed()))
if err != nil {
return err
}
err = h.CryptoStore.PutSecret(ctx, id.SecretXSUserSigning, base64.StdEncoding.EncodeToString(keys.UserSigningKey.Seed()))
if err != nil {
return err
}
return nil
}
func (h *HiClient) VerifyWithRecoveryKey(ctx context.Context, code string) error {
defer h.dispatchCurrentState()
keyID, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx)
if err != nil {
return fmt.Errorf("failed to get default SSSS key data: %w", err)
}
key, err := keyData.VerifyRecoveryKey(keyID, code)
if err != nil {
return err
}
err = h.Crypto.FetchCrossSigningKeysFromSSSS(ctx, key)
if err != nil {
return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err)
}
err = h.Crypto.SignOwnDevice(ctx, h.Crypto.OwnIdentity())
if err != nil {
return fmt.Errorf("failed to sign own device: %w", err)
}
err = h.Crypto.SignOwnMasterKey(ctx)
if err != nil {
return fmt.Errorf("failed to sign own master key: %w", err)
}
err = h.storeCrossSigningPrivateKeys(ctx)
if err != nil {
return fmt.Errorf("failed to store cross-signing private keys: %w", err)
}
err = h.fetchKeyBackupKey(ctx, key)
if err != nil {
return fmt.Errorf("failed to fetch key backup key: %w", err)
}
h.Verified = true
if !h.IsSyncing() {
go h.Sync()
}
return nil
}

120
pkg/rainbow/goldmark.go Normal file
View file

@ -0,0 +1,120 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package rainbow
import (
"fmt"
"unicode"
"github.com/rivo/uniseg"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/ast"
"github.com/yuin/goldmark/renderer"
"github.com/yuin/goldmark/renderer/html"
"github.com/yuin/goldmark/util"
"go.mau.fi/util/random"
)
// Extension is a goldmark extension that adds rainbow text coloring to the HTML renderer.
var Extension = &extRainbow{}
type extRainbow struct{}
type rainbowRenderer struct {
HardWraps bool
ColorID string
}
var defaultRB = &rainbowRenderer{HardWraps: true, ColorID: random.String(16)}
func (er *extRainbow) Extend(m goldmark.Markdown) {
m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(defaultRB, 0)))
}
func (rb *rainbowRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
reg.Register(ast.KindText, rb.renderText)
reg.Register(ast.KindString, rb.renderString)
}
type rainbowBufWriter struct {
util.BufWriter
ColorID string
}
func (rbw rainbowBufWriter) WriteString(s string) (int, error) {
i := 0
graphemes := uniseg.NewGraphemes(s)
for graphemes.Next() {
runes := graphemes.Runes()
if len(runes) == 1 && unicode.IsSpace(runes[0]) {
i2, err := rbw.BufWriter.WriteRune(runes[0])
i += i2
if err != nil {
return i, err
}
continue
}
i2, err := fmt.Fprintf(rbw.BufWriter, "<font color=\"%s\">%s</font>", rbw.ColorID, graphemes.Str())
i += i2
if err != nil {
return i, err
}
}
return i, nil
}
func (rbw rainbowBufWriter) Write(data []byte) (int, error) {
return rbw.WriteString(string(data))
}
func (rbw rainbowBufWriter) WriteByte(c byte) error {
_, err := rbw.WriteRune(rune(c))
return err
}
func (rbw rainbowBufWriter) WriteRune(r rune) (int, error) {
if unicode.IsSpace(r) {
return rbw.BufWriter.WriteRune(r)
} else {
return fmt.Fprintf(rbw.BufWriter, "<font color=\"%s\">%c</font>", rbw.ColorID, r)
}
}
func (rb *rainbowRenderer) renderText(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) {
if !entering {
return ast.WalkContinue, nil
}
n := node.(*ast.Text)
segment := n.Segment
if n.IsRaw() {
html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, segment.Value(source))
} else {
html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, segment.Value(source))
if n.HardLineBreak() || (n.SoftLineBreak() && rb.HardWraps) {
_, _ = w.WriteString("<br>\n")
} else if n.SoftLineBreak() {
_ = w.WriteByte('\n')
}
}
return ast.WalkContinue, nil
}
func (rb *rainbowRenderer) renderString(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) {
if !entering {
return ast.WalkContinue, nil
}
n := node.(*ast.String)
if n.IsCode() {
_, _ = w.Write(n.Value)
} else {
if n.IsRaw() {
html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, n.Value)
} else {
html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, n.Value)
}
}
return ast.WalkContinue, nil
}

56
pkg/rainbow/gradient.go Normal file
View file

@ -0,0 +1,56 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package rainbow
import (
"regexp"
"strings"
"github.com/lucasb-eyer/go-colorful"
)
// GradientTable from https://github.com/lucasb-eyer/go-colorful/blob/master/doc/gradientgen/gradientgen.go
type GradientTable []struct {
Col colorful.Color
Pos float64
}
func (gt GradientTable) GetInterpolatedColorFor(t float64) colorful.Color {
for i := 0; i < len(gt)-1; i++ {
c1 := gt[i]
c2 := gt[i+1]
if c1.Pos <= t && t <= c2.Pos {
t := (t - c1.Pos) / (c2.Pos - c1.Pos)
return c1.Col.BlendHcl(c2.Col, t).Clamped()
}
}
return gt[len(gt)-1].Col
}
var Gradient = GradientTable{
{colorful.LinearRgb(1, 0, 0), 0 / 11.0},
{colorful.LinearRgb(1, 0.5, 0), 1 / 11.0},
{colorful.LinearRgb(1, 1, 0), 2 / 11.0},
{colorful.LinearRgb(0.5, 1, 0), 3 / 11.0},
{colorful.LinearRgb(0, 1, 0), 4 / 11.0},
{colorful.LinearRgb(0, 1, 0.5), 5 / 11.0},
{colorful.LinearRgb(0, 1, 1), 6 / 11.0},
{colorful.LinearRgb(0, 0.5, 1), 7 / 11.0},
{colorful.LinearRgb(0, 0, 1), 8 / 11.0},
{colorful.LinearRgb(0.5, 0, 1), 9 / 11.0},
{colorful.LinearRgb(1, 0, 1), 10 / 11.0},
{colorful.LinearRgb(1, 0, 0.5), 11 / 11.0},
}
func ApplyColor(htmlBody string) string {
count := strings.Count(htmlBody, defaultRB.ColorID)
i := -1
return regexp.MustCompile(defaultRB.ColorID).ReplaceAllStringFunc(htmlBody, func(match string) string {
i++
return Gradient.GetInterpolatedColorFor(float64(i) / float64(count)).Hex()
})
}

View file

@ -82,11 +82,12 @@ var (
)
type tokenData struct {
Username string `json:"username"`
Expiry jsontime.Unix `json:"expiry"`
Username string `json:"username"`
Expiry jsontime.Unix `json:"expiry"`
ImageOnly bool `json:"image_only,omitempty"`
}
func (gmx *Gomuks) validateAuth(token string) bool {
func (gmx *Gomuks) validateAuth(token string, imageOnly bool) bool {
if len(token) > 500 {
return false
}
@ -110,19 +111,31 @@ func (gmx *Gomuks) validateAuth(token string) bool {
var td tokenData
err = json.Unmarshal(rawJSON, &td)
return err == nil && td.Username == gmx.Config.Web.Username && td.Expiry.After(time.Now())
return err == nil && td.Username == gmx.Config.Web.Username && td.Expiry.After(time.Now()) && td.ImageOnly == imageOnly
}
func (gmx *Gomuks) generateToken() (string, time.Time) {
expiry := time.Now().Add(7 * 24 * time.Hour)
data := exerrors.Must(json.Marshal(tokenData{
return gmx.signToken(tokenData{
Username: gmx.Config.Web.Username,
Expiry: jsontime.U(expiry),
}))
}), expiry
}
func (gmx *Gomuks) generateImageToken() string {
return gmx.signToken(tokenData{
Username: gmx.Config.Web.Username,
Expiry: jsontime.U(time.Now().Add(1 * time.Hour)),
ImageOnly: true,
})
}
func (gmx *Gomuks) signToken(td tokenData) string {
data := exerrors.Must(json.Marshal(td))
hasher := hmac.New(sha256.New, []byte(gmx.Config.Web.TokenKey))
hasher.Write(data)
checksum := hasher.Sum(nil)
return base64.RawURLEncoding.EncodeToString(data) + "." + base64.RawURLEncoding.EncodeToString(checksum), expiry
return base64.RawURLEncoding.EncodeToString(data) + "." + base64.RawURLEncoding.EncodeToString(checksum)
}
func (gmx *Gomuks) writeTokenCookie(w http.ResponseWriter) {
@ -137,7 +150,7 @@ func (gmx *Gomuks) writeTokenCookie(w http.ResponseWriter) {
func (gmx *Gomuks) Authenticate(w http.ResponseWriter, r *http.Request) {
authCookie, err := r.Cookie("gomuks_auth")
if err == nil && gmx.validateAuth(authCookie.Value) {
if err == nil && gmx.validateAuth(authCookie.Value, false) {
gmx.writeTokenCookie(w)
w.WriteHeader(http.StatusOK)
} else if username, password, ok := r.BasicAuth(); !ok {
@ -164,9 +177,23 @@ func isUserFetch(header http.Header) bool {
header.Get("Sec-Fetch-User") == "?1"
}
func isImageFetch(header http.Header) bool {
return header.Get("Sec-Fetch-Site") == "cross-site" &&
header.Get("Sec-Fetch-Mode") == "no-cors" &&
header.Get("Sec-Fetch-Dest") == "image"
}
func (gmx *Gomuks) AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Sec-Fetch-Site") != "" && r.Header.Get("Sec-Fetch-Site") != "same-origin" && !isUserFetch(r.Header) {
if strings.HasPrefix(r.URL.Path, "/media") &&
isImageFetch(r.Header) &&
gmx.validateAuth(r.URL.Query().Get("image_auth"), true) &&
r.URL.Query().Get("encrypted") == "false" {
next.ServeHTTP(w, r)
return
} else if r.Header.Get("Sec-Fetch-Site") != "" &&
r.Header.Get("Sec-Fetch-Site") != "same-origin" &&
!isUserFetch(r.Header) {
hlog.FromRequest(r).Debug().
Str("site", r.Header.Get("Sec-Fetch-Site")).
Str("dest", r.Header.Get("Sec-Fetch-Dest")).
@ -181,7 +208,7 @@ func (gmx *Gomuks) AuthMiddleware(next http.Handler) http.Handler {
if err != nil {
ErrMissingCookie.Write(w)
return
} else if !gmx.validateAuth(authCookie.Value) {
} else if !gmx.validateAuth(authCookie.Value, false) {
ErrInvalidCookie.Write(w)
return
}

View file

@ -29,6 +29,8 @@ function App() {
const clientState = useEventAsState(client.state)
;((window as unknown) as { client: Client }).client = client
useEffect(() => {
Notification.requestPermission()
.then(permission => console.log("Notification permission:", permission))
client.rpc.start()
return () => client.rpc.stop()
}, [client])

View file

@ -39,6 +39,8 @@ export default class Client {
this.store.applyDecrypted(ev.data)
} else if (ev.command === "send_complete") {
this.store.applySendComplete(ev.data)
} else if (ev.command === "image_auth_token") {
this.store.imageAuthToken = ev.data
}
}

View file

@ -14,9 +14,11 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import unhomoglyph from "unhomoglyph"
import { getMediaURL } from "@/api/media.ts"
import { NonNullCachedEventDispatcher } from "@/util/eventdispatcher.ts"
import { focused } from "@/util/focus.ts"
import type {
ContentURI,
ContentURI, EventRowID,
EventsDecryptedData,
MemDBEvent,
RoomID,
@ -34,6 +36,8 @@ export interface RoomListEntry {
name: string
search_name: string
avatar?: ContentURI
unread: number
highlighted: boolean
}
// eslint-disable-next-line no-misleading-character-class
@ -49,9 +53,13 @@ export function toSearchableString(str: string): string {
export class StateStore {
readonly rooms: Map<RoomID, RoomStateStore> = new Map()
readonly roomList = new NonNullCachedEventDispatcher<RoomListEntry[]>([])
switchRoom?: (roomID: RoomID) => void
imageAuthToken?: string
#roomListEntryChanged(entry: SyncRoom, oldEntry: RoomStateStore): boolean {
return entry.meta.sorting_timestamp !== oldEntry.meta.current.sorting_timestamp ||
entry.meta.unread_notifications !== oldEntry.meta.current.unread_notifications ||
entry.meta.unread_highlights !== oldEntry.meta.current.unread_highlights ||
entry.meta.preview_event_rowid !== oldEntry.meta.current.preview_event_rowid ||
entry.events.findIndex(evt => evt.rowid === entry.meta.preview_event_rowid) !== -1
}
@ -71,6 +79,8 @@ export class StateStore {
name,
search_name: toSearchableString(name),
avatar: entry.meta.avatar,
unread: entry.meta.unread_notifications,
highlighted: entry.meta.unread_highlights > 0,
}
}
@ -90,6 +100,12 @@ export class StateStore {
if (roomListEntryChanged) {
changedRoomListEntries.set(roomID, this.#makeRoomListEntry(data, room))
}
if (Notification.permission === "granted" && !focused.current) {
for (const notification of data.notifications) {
this.showNotification(room, notification.event_rowid, notification.sound)
}
}
}
let updatedRoomList: RoomListEntry[] | undefined
@ -117,6 +133,40 @@ export class StateStore {
}
}
showNotification(room: RoomStateStore, rowid: EventRowID, sound: boolean) {
const evt = room.eventsByRowID.get(rowid)
if (!evt || typeof evt.content.body !== "string") {
return
}
let body = evt.content.body
if (body.length > 200) {
body = body.slice(0, 150) + " […]"
}
const memberEvt = room.getStateEvent("m.room.member", evt.sender)
const icon = `${getMediaURL(memberEvt?.content.avatar_url)}&image_auth=${this.imageAuthToken}`
const roomName = room.meta.current.name ?? "Unnamed room"
const senderName = memberEvt?.content.displayname ?? evt.sender
const title = senderName === roomName ? senderName : `${senderName} (${roomName})`
const notif = new Notification(title, {
body,
icon,
badge: "/gomuks.png",
// timestamp: evt.timestamp,
// image: ...,
tag: rowid.toString(),
})
notif.onclick = () => this.onClickNotification(room.roomID)
if (sound) {
// TODO play sound
}
}
onClickNotification(roomID: RoomID) {
if (this.switchRoom) {
this.switchRoom(roomID)
}
}
applySendComplete(data: SendCompleteData) {
const room = this.rooms.get(data.event.room_id)
if (!room) {

View file

@ -147,6 +147,7 @@ export class RoomStateStore {
if (memEvt.last_edit) {
memEvt.orig_content = memEvt.content
memEvt.content = memEvt.last_edit.content["m.new_content"]
memEvt.local_content = memEvt.last_edit.local_content
}
} else if (memEvt.relation_type === "m.replace" && memEvt.relates_to) {
const editTarget = this.eventsByID.get(memEvt.relates_to)
@ -154,6 +155,7 @@ export class RoomStateStore {
editTarget.last_edit = memEvt
editTarget.orig_content = editTarget.content
editTarget.content = memEvt.content["m.new_content"]
editTarget.local_content = memEvt.local_content
}
}
this.eventsByRowID.set(memEvt.rowid, memEvt)

View file

@ -60,12 +60,22 @@ export interface EventsDecryptedEvent extends RPCCommand<EventsDecryptedData> {
command: "events_decrypted"
}
export interface ImageAuthTokenEvent extends RPCCommand<string> {
command: "image_auth_token"
}
export interface SyncRoom {
meta: DBRoom
timeline: TimelineRowTuple[]
events: RawDBEvent[]
state: Record<EventType, Record<string, EventRowID>>
reset: boolean
notifications: SyncNotification[]
}
export interface SyncNotification {
event_rowid: EventRowID
sound: boolean
}
export interface SyncCompleteData {
@ -97,4 +107,5 @@ export type RPCEvent =
TypingEvent |
SendCompleteEvent |
EventsDecryptedEvent |
SyncCompleteEvent
SyncCompleteEvent |
ImageAuthTokenEvent

View file

@ -59,6 +59,9 @@ export interface DBRoom {
preview_event_rowid: EventRowID
sorting_timestamp: number
unread_highlights: number
unread_notifications: number
unread_messages: number
prev_batch: string
}
@ -66,6 +69,18 @@ export interface DBRoom {
//eslint-disable-next-line @typescript-eslint/no-explicit-any
export type UnknownEventContent = Record<string, any>
export enum UnreadType {
None = 0b0000,
Normal = 0b0001,
Notify = 0b0010,
Highlight = 0b0100,
Sound = 0b1000,
}
export interface LocalContent {
sanitized_html?: TrustedHTML
}
export interface BaseDBEvent {
rowid: EventRowID
timeline_rowid: TimelineRowID
@ -79,6 +94,7 @@ export interface BaseDBEvent {
content: UnknownEventContent
unsigned: EventUnsigned
local_content?: LocalContent
transaction_id?: string
@ -91,6 +107,7 @@ export interface BaseDBEvent {
reactions?: Record<string, number>
last_edit_rowid?: EventRowID
unread_type: UnreadType
}
export interface RawDBEvent extends BaseDBEvent {

View file

@ -31,6 +31,7 @@ const MainScreen = () => {
.catch(err => console.error("Failed to load room state", err))
}
}, [client])
client.store.switchRoom = setActiveRoom
const clearActiveRoom = useCallback(() => setActiveRoomID(null), [])
return <main className={`matrix-main ${activeRoom ? "room-selected" : ""}`}>
<RoomList setActiveRoom={setActiveRoom} activeRoomID={activeRoomID} />

View file

@ -60,6 +60,11 @@ const Entry = ({ room, setActiveRoom, isActive, hidden }: RoomListEntryProps) =>
<div className="room-name">{room.name}</div>
{previewText && <div className="message-preview" title={previewText}>{croppedPreviewText}</div>}
</div>
{room.unread ? <div className="room-entry-unreads">
<div className={`unread-count ${room.highlighted ? "highlighted" : ""}`}>
{room.unread}
</div>
</div> : null}
</div>
}

View file

@ -78,6 +78,26 @@ div.room-entry {
}
}
}
> div.room-entry-unreads {
display: flex;
align-items: center;
width: 3rem;
> div.unread-count {
width: 1.5rem;
height: 1.5rem;
border-radius: 50%;
background-color: green;
text-align: center;
color: white;
font-weight: bold;
&.highlighted {
background-color: darkred;
}
}
}
}
img.avatar {

View file

@ -166,7 +166,20 @@ div.html-body {
vertical-align: middle;
}
span[data-mx-spoiler] {
img.hicli-custom-emoji {
vertical-align: middle;
height: 24px;
width: auto;
max-width: 72px;
}
img.hicli-sizeless-inline-img {
height: 24px;
width: auto;
max-width: 72px;
}
span[data-mx-spoiler], span.hicli-spoiler {
filter: blur(4px);
transition: filter .5s;
cursor: pointer;

View file

@ -24,7 +24,11 @@ import { ReplyIDBody } from "./ReplyBody.tsx"
import EncryptedBody from "./content/EncryptedBody.tsx"
import HiddenEvent from "./content/HiddenEvent.tsx"
import MemberBody from "./content/MemberBody.tsx"
import { MediaMessageBody, TextMessageBody, UnknownMessageBody } from "./content/MessageBody.tsx"
import {
MediaMessageBody,
TextMessageBody,
UnknownMessageBody,
} from "./content/MessageBody.tsx"
import RedactedBody from "./content/RedactedBody.tsx"
import { EventContentProps } from "./content/props.ts"
import ErrorIcon from "../../icons/error.svg?react"

View file

@ -13,11 +13,9 @@
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import { CSSProperties, use, useMemo } from "react"
import sanitizeHtml from "sanitize-html"
import { CSSProperties, use } from "react"
import { getEncryptedMediaURL, getMediaURL } from "@/api/media.ts"
import type { EventType, MediaMessageEventContent, MessageEventContent } from "@/api/types"
import { sanitizeHtmlParams } from "@/util/html.ts"
import { calculateMediaSize } from "@/util/mediasize.ts"
import { LightboxContext } from "../../Lightbox.tsx"
import { EventContentProps } from "./props.ts"
@ -32,14 +30,10 @@ const onClickHTML = (evt: React.MouseEvent<HTMLDivElement>) => {
export const TextMessageBody = ({ event }: EventContentProps) => {
const content = event.content as MessageEventContent
const __html = useMemo(() => {
if (content.format === "org.matrix.custom.html") {
return sanitizeHtml(content.formatted_body!, sanitizeHtmlParams)
}
return undefined
}, [content.format, content.formatted_body])
if (__html) {
return <div onClick={onClickHTML} className="message-text html-body" dangerouslySetInnerHTML={{ __html }}/>
if (event.local_content?.sanitized_html) {
return <div onClick={onClickHTML} className="message-text html-body" dangerouslySetInnerHTML={{
__html: event.local_content!.sanitized_html!,
}}/>
}
return <div className="message-text plaintext-body">{content.body}</div>
}

View file

@ -15,7 +15,7 @@
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import { NonNullCachedEventDispatcher, useNonNullEventAsState } from "@/util/eventdispatcher.ts"
const focused = new NonNullCachedEventDispatcher(document.hasFocus())
export const focused = new NonNullCachedEventDispatcher(document.hasFocus())
window.addEventListener("focus", () => focused.emit(true))
window.addEventListener("blur", () => focused.emit(false))

View file

@ -28,10 +28,12 @@ import (
"github.com/coder/websocket"
"github.com/rs/zerolog"
"go.mau.fi/util/exerrors"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/hicli"
"maunium.net/go/mautrix/hicli/database"
"maunium.net/go/mautrix/id"
"go.mau.fi/gomuks/pkg/hicli"
"go.mau.fi/gomuks/pkg/hicli/database"
)
func writeCmd(ctx context.Context, conn *websocket.Conn, cmd *hicli.JSONCommand) error {
@ -120,6 +122,17 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
lastDataReceived := &atomic.Int64{}
lastDataReceived.Store(time.Now().UnixMilli())
const RecvTimeout = 60 * time.Second
lastImageAuthTokenSent := time.Now()
sendImageAuthToken := func() {
err := writeCmd(ctx, conn, &hicli.JSONCommand{
Command: "image_auth_token",
Data: exerrors.Must(json.Marshal(gmx.generateImageToken())),
})
if err != nil {
log.Err(err).Msg("Failed to write image auth token message")
return
}
}
go func() {
defer recoverPanic("event loop")
defer closeOnce.Do(forceClose)
@ -137,6 +150,10 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent outgoing event")
}
case <-ticker.C:
if time.Since(lastImageAuthTokenSent) > 30*time.Minute {
sendImageAuthToken()
lastImageAuthTokenSent = time.Now()
}
if time.Now().UnixMilli()-lastDataReceived.Load() > RecvTimeout.Milliseconds() {
log.Warn().Msg("No data received in a minute, closing connection")
_ = conn.Close(StatusPingTimeout, "Ping timeout")
@ -187,6 +204,7 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
log.Err(initErr).Msg("Failed to write init message")
return
}
go sendImageAuthToken()
go gmx.sendInitialData(ctx, conn)
log.Debug().Msg("Connection initialization complete")
var closeErr websocket.CloseError
@ -242,10 +260,11 @@ func (gmx *Gomuks) sendInitialData(ctx context.Context, conn *websocket.Conn) {
}
maxTS = room.SortingTimestamp.Time
syncRoom := &hicli.SyncRoom{
Meta: room,
Events: make([]*database.Event, 0, 2),
Timeline: make([]database.TimelineRowTuple, 0),
State: map[event.Type]map[string]database.EventRowID{},
Meta: room,
Events: make([]*database.Event, 0, 2),
Timeline: make([]database.TimelineRowTuple, 0),
State: map[event.Type]map[string]database.EventRowID{},
Notifications: make([]hicli.SyncNotification, 0),
}
payload.Rooms[room.ID] = syncRoom
if room.PreviewEventRowID != 0 {
@ -255,6 +274,7 @@ func (gmx *Gomuks) sendInitialData(ctx context.Context, conn *websocket.Conn) {
return
}
if previewEvent != nil {
gmx.Client.ReprocessExistingEvent(ctx, previewEvent)
previewMember, err := gmx.Client.DB.CurrentState.Get(ctx, room.ID, event.StateMember, previewEvent.Sender.String())
if err != nil {
log.Err(err).Msg("Failed to get preview member event for room")
@ -269,6 +289,7 @@ func (gmx *Gomuks) sendInitialData(ctx context.Context, conn *websocket.Conn) {
if err != nil {
log.Err(err).Msg("Failed to get last edit for preview event")
} else if lastEdit != nil {
gmx.Client.ReprocessExistingEvent(ctx, lastEdit)
syncRoom.Events = append(syncRoom.Events, lastEdit)
}
}