mirror of
https://github.com/tulir/gomuks.git
synced 2025-04-18 17:53:42 -05:00
all: move hicli from mautrix-go and add more features
This commit is contained in:
parent
d79be2b8cf
commit
1db1d2db5c
53 changed files with 6068 additions and 45 deletions
16
go.mod
16
go.mod
|
@ -5,33 +5,35 @@ go 1.23.0
|
|||
toolchain go1.23.2
|
||||
|
||||
require (
|
||||
github.com/chzyer/readline v1.5.1
|
||||
github.com/coder/websocket v1.8.12
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/rivo/uniseg v0.4.7
|
||||
github.com/rs/zerolog v1.33.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/yuin/goldmark v1.7.7
|
||||
go.mau.fi/util v0.8.1
|
||||
go.mau.fi/zeroconfig v0.1.3
|
||||
golang.org/x/crypto v0.28.0
|
||||
golang.org/x/net v0.30.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
maunium.net/go/mauflag v1.0.0
|
||||
maunium.net/go/mautrix v0.21.2-0.20241016143340-1d4c2d255455
|
||||
maunium.net/go/mautrix v0.21.2-0.20241017173032-367828429297
|
||||
mvdan.cc/xurls/v2 v2.5.0
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/xid v1.6.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/yuin/goldmark v1.7.7 // indirect
|
||||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c // indirect
|
||||
golang.org/x/net v0.30.0 // indirect
|
||||
golang.org/x/sys v0.26.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
)
|
||||
|
|
13
go.sum
13
go.sum
|
@ -2,6 +2,12 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
|||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
|
||||
github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
|
||||
github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI=
|
||||
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
|
||||
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
|
||||
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
|
||||
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
|
||||
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
|
||||
|
@ -53,6 +59,7 @@ golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBn
|
|||
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8=
|
||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
||||
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
@ -66,5 +73,7 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
|
||||
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
|
||||
maunium.net/go/mautrix v0.21.2-0.20241016143340-1d4c2d255455 h1:6UznUe9nDojckcvBNq0h1vZFM/KmA7Xs24YowG8iE+4=
|
||||
maunium.net/go/mautrix v0.21.2-0.20241016143340-1d4c2d255455/go.mod h1:7F/S6XAdyc/6DW+Q7xyFXRSPb6IjfqMb1OMepQ8C8OE=
|
||||
maunium.net/go/mautrix v0.21.2-0.20241017173032-367828429297 h1:8CybV+x9HPh4p41nIqJMKHI0bUF0g0bEozq49ytNVlc=
|
||||
maunium.net/go/mautrix v0.21.2-0.20241017173032-367828429297/go.mod h1:sjCZR1R/3NET/WjkcXPL6WpAHlWKku9HjRsdOkbM8Qw=
|
||||
mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8=
|
||||
mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE=
|
||||
|
|
|
@ -34,7 +34,8 @@ import (
|
|||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"go.mau.fi/util/exzerolog"
|
||||
"maunium.net/go/mautrix/hicli"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli"
|
||||
)
|
||||
|
||||
type Gomuks struct {
|
||||
|
@ -144,6 +145,7 @@ func (gmx *Gomuks) SetupLog() {
|
|||
}
|
||||
|
||||
func (gmx *Gomuks) StartClient() {
|
||||
hicli.HTMLSanitizerImgSrcTemplate = "_gomuks/media/%s/%s"
|
||||
rawDB, err := dbutil.NewFromConfig("gomuks", dbutil.Config{
|
||||
PoolConfig: dbutil.PoolConfig{
|
||||
Type: "sqlite3-fk-wal",
|
||||
|
|
3
main.go
3
main.go
|
@ -27,7 +27,8 @@ import (
|
|||
_ "go.mau.fi/util/dbutil/litestream"
|
||||
flag "maunium.net/go/mauflag"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/hicli"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
4
media.go
4
media.go
|
@ -16,10 +16,10 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/jsontime"
|
||||
"go.mau.fi/util/ptr"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/hicli/database"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
)
|
||||
|
||||
var ErrBadGateway = mautrix.RespError{
|
||||
|
|
373
pkg/hicli/LICENSE
Normal file
373
pkg/hicli/LICENSE
Normal file
|
@ -0,0 +1,373 @@
|
|||
Mozilla Public License Version 2.0
|
||||
==================================
|
||||
|
||||
1. Definitions
|
||||
--------------
|
||||
|
||||
1.1. "Contributor"
|
||||
means each individual or legal entity that creates, contributes to
|
||||
the creation of, or owns Covered Software.
|
||||
|
||||
1.2. "Contributor Version"
|
||||
means the combination of the Contributions of others (if any) used
|
||||
by a Contributor and that particular Contributor's Contribution.
|
||||
|
||||
1.3. "Contribution"
|
||||
means Covered Software of a particular Contributor.
|
||||
|
||||
1.4. "Covered Software"
|
||||
means Source Code Form to which the initial Contributor has attached
|
||||
the notice in Exhibit A, the Executable Form of such Source Code
|
||||
Form, and Modifications of such Source Code Form, in each case
|
||||
including portions thereof.
|
||||
|
||||
1.5. "Incompatible With Secondary Licenses"
|
||||
means
|
||||
|
||||
(a) that the initial Contributor has attached the notice described
|
||||
in Exhibit B to the Covered Software; or
|
||||
|
||||
(b) that the Covered Software was made available under the terms of
|
||||
version 1.1 or earlier of the License, but not also under the
|
||||
terms of a Secondary License.
|
||||
|
||||
1.6. "Executable Form"
|
||||
means any form of the work other than Source Code Form.
|
||||
|
||||
1.7. "Larger Work"
|
||||
means a work that combines Covered Software with other material, in
|
||||
a separate file or files, that is not Covered Software.
|
||||
|
||||
1.8. "License"
|
||||
means this document.
|
||||
|
||||
1.9. "Licensable"
|
||||
means having the right to grant, to the maximum extent possible,
|
||||
whether at the time of the initial grant or subsequently, any and
|
||||
all of the rights conveyed by this License.
|
||||
|
||||
1.10. "Modifications"
|
||||
means any of the following:
|
||||
|
||||
(a) any file in Source Code Form that results from an addition to,
|
||||
deletion from, or modification of the contents of Covered
|
||||
Software; or
|
||||
|
||||
(b) any new file in Source Code Form that contains any Covered
|
||||
Software.
|
||||
|
||||
1.11. "Patent Claims" of a Contributor
|
||||
means any patent claim(s), including without limitation, method,
|
||||
process, and apparatus claims, in any patent Licensable by such
|
||||
Contributor that would be infringed, but for the grant of the
|
||||
License, by the making, using, selling, offering for sale, having
|
||||
made, import, or transfer of either its Contributions or its
|
||||
Contributor Version.
|
||||
|
||||
1.12. "Secondary License"
|
||||
means either the GNU General Public License, Version 2.0, the GNU
|
||||
Lesser General Public License, Version 2.1, the GNU Affero General
|
||||
Public License, Version 3.0, or any later versions of those
|
||||
licenses.
|
||||
|
||||
1.13. "Source Code Form"
|
||||
means the form of the work preferred for making modifications.
|
||||
|
||||
1.14. "You" (or "Your")
|
||||
means an individual or a legal entity exercising rights under this
|
||||
License. For legal entities, "You" includes any entity that
|
||||
controls, is controlled by, or is under common control with You. For
|
||||
purposes of this definition, "control" means (a) the power, direct
|
||||
or indirect, to cause the direction or management of such entity,
|
||||
whether by contract or otherwise, or (b) ownership of more than
|
||||
fifty percent (50%) of the outstanding shares or beneficial
|
||||
ownership of such entity.
|
||||
|
||||
2. License Grants and Conditions
|
||||
--------------------------------
|
||||
|
||||
2.1. Grants
|
||||
|
||||
Each Contributor hereby grants You a world-wide, royalty-free,
|
||||
non-exclusive license:
|
||||
|
||||
(a) under intellectual property rights (other than patent or trademark)
|
||||
Licensable by such Contributor to use, reproduce, make available,
|
||||
modify, display, perform, distribute, and otherwise exploit its
|
||||
Contributions, either on an unmodified basis, with Modifications, or
|
||||
as part of a Larger Work; and
|
||||
|
||||
(b) under Patent Claims of such Contributor to make, use, sell, offer
|
||||
for sale, have made, import, and otherwise transfer either its
|
||||
Contributions or its Contributor Version.
|
||||
|
||||
2.2. Effective Date
|
||||
|
||||
The licenses granted in Section 2.1 with respect to any Contribution
|
||||
become effective for each Contribution on the date the Contributor first
|
||||
distributes such Contribution.
|
||||
|
||||
2.3. Limitations on Grant Scope
|
||||
|
||||
The licenses granted in this Section 2 are the only rights granted under
|
||||
this License. No additional rights or licenses will be implied from the
|
||||
distribution or licensing of Covered Software under this License.
|
||||
Notwithstanding Section 2.1(b) above, no patent license is granted by a
|
||||
Contributor:
|
||||
|
||||
(a) for any code that a Contributor has removed from Covered Software;
|
||||
or
|
||||
|
||||
(b) for infringements caused by: (i) Your and any other third party's
|
||||
modifications of Covered Software, or (ii) the combination of its
|
||||
Contributions with other software (except as part of its Contributor
|
||||
Version); or
|
||||
|
||||
(c) under Patent Claims infringed by Covered Software in the absence of
|
||||
its Contributions.
|
||||
|
||||
This License does not grant any rights in the trademarks, service marks,
|
||||
or logos of any Contributor (except as may be necessary to comply with
|
||||
the notice requirements in Section 3.4).
|
||||
|
||||
2.4. Subsequent Licenses
|
||||
|
||||
No Contributor makes additional grants as a result of Your choice to
|
||||
distribute the Covered Software under a subsequent version of this
|
||||
License (see Section 10.2) or under the terms of a Secondary License (if
|
||||
permitted under the terms of Section 3.3).
|
||||
|
||||
2.5. Representation
|
||||
|
||||
Each Contributor represents that the Contributor believes its
|
||||
Contributions are its original creation(s) or it has sufficient rights
|
||||
to grant the rights to its Contributions conveyed by this License.
|
||||
|
||||
2.6. Fair Use
|
||||
|
||||
This License is not intended to limit any rights You have under
|
||||
applicable copyright doctrines of fair use, fair dealing, or other
|
||||
equivalents.
|
||||
|
||||
2.7. Conditions
|
||||
|
||||
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
|
||||
in Section 2.1.
|
||||
|
||||
3. Responsibilities
|
||||
-------------------
|
||||
|
||||
3.1. Distribution of Source Form
|
||||
|
||||
All distribution of Covered Software in Source Code Form, including any
|
||||
Modifications that You create or to which You contribute, must be under
|
||||
the terms of this License. You must inform recipients that the Source
|
||||
Code Form of the Covered Software is governed by the terms of this
|
||||
License, and how they can obtain a copy of this License. You may not
|
||||
attempt to alter or restrict the recipients' rights in the Source Code
|
||||
Form.
|
||||
|
||||
3.2. Distribution of Executable Form
|
||||
|
||||
If You distribute Covered Software in Executable Form then:
|
||||
|
||||
(a) such Covered Software must also be made available in Source Code
|
||||
Form, as described in Section 3.1, and You must inform recipients of
|
||||
the Executable Form how they can obtain a copy of such Source Code
|
||||
Form by reasonable means in a timely manner, at a charge no more
|
||||
than the cost of distribution to the recipient; and
|
||||
|
||||
(b) You may distribute such Executable Form under the terms of this
|
||||
License, or sublicense it under different terms, provided that the
|
||||
license for the Executable Form does not attempt to limit or alter
|
||||
the recipients' rights in the Source Code Form under this License.
|
||||
|
||||
3.3. Distribution of a Larger Work
|
||||
|
||||
You may create and distribute a Larger Work under terms of Your choice,
|
||||
provided that You also comply with the requirements of this License for
|
||||
the Covered Software. If the Larger Work is a combination of Covered
|
||||
Software with a work governed by one or more Secondary Licenses, and the
|
||||
Covered Software is not Incompatible With Secondary Licenses, this
|
||||
License permits You to additionally distribute such Covered Software
|
||||
under the terms of such Secondary License(s), so that the recipient of
|
||||
the Larger Work may, at their option, further distribute the Covered
|
||||
Software under the terms of either this License or such Secondary
|
||||
License(s).
|
||||
|
||||
3.4. Notices
|
||||
|
||||
You may not remove or alter the substance of any license notices
|
||||
(including copyright notices, patent notices, disclaimers of warranty,
|
||||
or limitations of liability) contained within the Source Code Form of
|
||||
the Covered Software, except that You may alter any license notices to
|
||||
the extent required to remedy known factual inaccuracies.
|
||||
|
||||
3.5. Application of Additional Terms
|
||||
|
||||
You may choose to offer, and to charge a fee for, warranty, support,
|
||||
indemnity or liability obligations to one or more recipients of Covered
|
||||
Software. However, You may do so only on Your own behalf, and not on
|
||||
behalf of any Contributor. You must make it absolutely clear that any
|
||||
such warranty, support, indemnity, or liability obligation is offered by
|
||||
You alone, and You hereby agree to indemnify every Contributor for any
|
||||
liability incurred by such Contributor as a result of warranty, support,
|
||||
indemnity or liability terms You offer. You may include additional
|
||||
disclaimers of warranty and limitations of liability specific to any
|
||||
jurisdiction.
|
||||
|
||||
4. Inability to Comply Due to Statute or Regulation
|
||||
---------------------------------------------------
|
||||
|
||||
If it is impossible for You to comply with any of the terms of this
|
||||
License with respect to some or all of the Covered Software due to
|
||||
statute, judicial order, or regulation then You must: (a) comply with
|
||||
the terms of this License to the maximum extent possible; and (b)
|
||||
describe the limitations and the code they affect. Such description must
|
||||
be placed in a text file included with all distributions of the Covered
|
||||
Software under this License. Except to the extent prohibited by statute
|
||||
or regulation, such description must be sufficiently detailed for a
|
||||
recipient of ordinary skill to be able to understand it.
|
||||
|
||||
5. Termination
|
||||
--------------
|
||||
|
||||
5.1. The rights granted under this License will terminate automatically
|
||||
if You fail to comply with any of its terms. However, if You become
|
||||
compliant, then the rights granted under this License from a particular
|
||||
Contributor are reinstated (a) provisionally, unless and until such
|
||||
Contributor explicitly and finally terminates Your grants, and (b) on an
|
||||
ongoing basis, if such Contributor fails to notify You of the
|
||||
non-compliance by some reasonable means prior to 60 days after You have
|
||||
come back into compliance. Moreover, Your grants from a particular
|
||||
Contributor are reinstated on an ongoing basis if such Contributor
|
||||
notifies You of the non-compliance by some reasonable means, this is the
|
||||
first time You have received notice of non-compliance with this License
|
||||
from such Contributor, and You become compliant prior to 30 days after
|
||||
Your receipt of the notice.
|
||||
|
||||
5.2. If You initiate litigation against any entity by asserting a patent
|
||||
infringement claim (excluding declaratory judgment actions,
|
||||
counter-claims, and cross-claims) alleging that a Contributor Version
|
||||
directly or indirectly infringes any patent, then the rights granted to
|
||||
You by any and all Contributors for the Covered Software under Section
|
||||
2.1 of this License shall terminate.
|
||||
|
||||
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
|
||||
end user license agreements (excluding distributors and resellers) which
|
||||
have been validly granted by You or Your distributors under this License
|
||||
prior to termination shall survive termination.
|
||||
|
||||
************************************************************************
|
||||
* *
|
||||
* 6. Disclaimer of Warranty *
|
||||
* ------------------------- *
|
||||
* *
|
||||
* Covered Software is provided under this License on an "as is" *
|
||||
* basis, without warranty of any kind, either expressed, implied, or *
|
||||
* statutory, including, without limitation, warranties that the *
|
||||
* Covered Software is free of defects, merchantable, fit for a *
|
||||
* particular purpose or non-infringing. The entire risk as to the *
|
||||
* quality and performance of the Covered Software is with You. *
|
||||
* Should any Covered Software prove defective in any respect, You *
|
||||
* (not any Contributor) assume the cost of any necessary servicing, *
|
||||
* repair, or correction. This disclaimer of warranty constitutes an *
|
||||
* essential part of this License. No use of any Covered Software is *
|
||||
* authorized under this License except under this disclaimer. *
|
||||
* *
|
||||
************************************************************************
|
||||
|
||||
************************************************************************
|
||||
* *
|
||||
* 7. Limitation of Liability *
|
||||
* -------------------------- *
|
||||
* *
|
||||
* Under no circumstances and under no legal theory, whether tort *
|
||||
* (including negligence), contract, or otherwise, shall any *
|
||||
* Contributor, or anyone who distributes Covered Software as *
|
||||
* permitted above, be liable to You for any direct, indirect, *
|
||||
* special, incidental, or consequential damages of any character *
|
||||
* including, without limitation, damages for lost profits, loss of *
|
||||
* goodwill, work stoppage, computer failure or malfunction, or any *
|
||||
* and all other commercial damages or losses, even if such party *
|
||||
* shall have been informed of the possibility of such damages. This *
|
||||
* limitation of liability shall not apply to liability for death or *
|
||||
* personal injury resulting from such party's negligence to the *
|
||||
* extent applicable law prohibits such limitation. Some *
|
||||
* jurisdictions do not allow the exclusion or limitation of *
|
||||
* incidental or consequential damages, so this exclusion and *
|
||||
* limitation may not apply to You. *
|
||||
* *
|
||||
************************************************************************
|
||||
|
||||
8. Litigation
|
||||
-------------
|
||||
|
||||
Any litigation relating to this License may be brought only in the
|
||||
courts of a jurisdiction where the defendant maintains its principal
|
||||
place of business and such litigation shall be governed by laws of that
|
||||
jurisdiction, without reference to its conflict-of-law provisions.
|
||||
Nothing in this Section shall prevent a party's ability to bring
|
||||
cross-claims or counter-claims.
|
||||
|
||||
9. Miscellaneous
|
||||
----------------
|
||||
|
||||
This License represents the complete agreement concerning the subject
|
||||
matter hereof. If any provision of this License is held to be
|
||||
unenforceable, such provision shall be reformed only to the extent
|
||||
necessary to make it enforceable. Any law or regulation which provides
|
||||
that the language of a contract shall be construed against the drafter
|
||||
shall not be used to construe this License against a Contributor.
|
||||
|
||||
10. Versions of the License
|
||||
---------------------------
|
||||
|
||||
10.1. New Versions
|
||||
|
||||
Mozilla Foundation is the license steward. Except as provided in Section
|
||||
10.3, no one other than the license steward has the right to modify or
|
||||
publish new versions of this License. Each version will be given a
|
||||
distinguishing version number.
|
||||
|
||||
10.2. Effect of New Versions
|
||||
|
||||
You may distribute the Covered Software under the terms of the version
|
||||
of the License under which You originally received the Covered Software,
|
||||
or under the terms of any subsequent version published by the license
|
||||
steward.
|
||||
|
||||
10.3. Modified Versions
|
||||
|
||||
If you create software not governed by this License, and you want to
|
||||
create a new license for such software, you may create and use a
|
||||
modified version of this License if you rename the license and remove
|
||||
any references to the name of the license steward (except to note that
|
||||
such modified license differs from this License).
|
||||
|
||||
10.4. Distributing Source Code Form that is Incompatible With Secondary
|
||||
Licenses
|
||||
|
||||
If You choose to distribute Source Code Form that is Incompatible With
|
||||
Secondary Licenses under the terms of this version of the License, the
|
||||
notice described in Exhibit B of this License must be attached.
|
||||
|
||||
Exhibit A - Source Code Form License Notice
|
||||
-------------------------------------------
|
||||
|
||||
This Source Code Form is subject to the terms of the Mozilla Public
|
||||
License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
If it is not possible or desirable to put the notice in a particular
|
||||
file, then You may include the notice in a location (such as a LICENSE
|
||||
file in a relevant directory) where a recipient would be likely to look
|
||||
for such a notice.
|
||||
|
||||
You may add additional accurate notices of copyright ownership.
|
||||
|
||||
Exhibit B - "Incompatible With Secondary Licenses" Notice
|
||||
---------------------------------------------------------
|
||||
|
||||
This Source Code Form is "Incompatible With Secondary Licenses", as
|
||||
defined by the Mozilla Public License, v. 2.0.
|
64
pkg/hicli/cryptohelper.go
Normal file
64
pkg/hicli/cryptohelper.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type hiCryptoHelper HiClient
|
||||
|
||||
var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil)
|
||||
|
||||
func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*event.EncryptedEventContent, error) {
|
||||
roomMeta, err := h.DB.Room.Get(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get room metadata: %w", err)
|
||||
} else if roomMeta == nil {
|
||||
return nil, fmt.Errorf("unknown room")
|
||||
}
|
||||
return (*HiClient)(h).Encrypt(ctx, roomMeta, evtType, content)
|
||||
}
|
||||
|
||||
func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
|
||||
return h.Crypto.DecryptMegolmEvent(ctx, evt)
|
||||
}
|
||||
|
||||
func (h *hiCryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
return h.Crypto.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
|
||||
}
|
||||
|
||||
func (h *hiCryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
|
||||
err := h.Crypto.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
|
||||
userID: {deviceID},
|
||||
h.Account.UserID: {"*"},
|
||||
})
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Stringer("room_id", roomID).
|
||||
Stringer("session_id", sessionID).
|
||||
Stringer("user_id", userID).
|
||||
Msg("Failed to send room key request")
|
||||
} else {
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Stringer("room_id", roomID).
|
||||
Stringer("session_id", sessionID).
|
||||
Stringer("user_id", userID).
|
||||
Msg("Sent room key request")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hiCryptoHelper) Init(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
73
pkg/hicli/database/account.go
Normal file
73
pkg/hicli/database/account.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
getAccountQuery = `SELECT user_id, device_id, access_token, homeserver_url, next_batch FROM account WHERE user_id = $1`
|
||||
putNextBatchQuery = `UPDATE account SET next_batch = $1 WHERE user_id = $2`
|
||||
upsertAccountQuery = `
|
||||
INSERT INTO account (user_id, device_id, access_token, homeserver_url, next_batch)
|
||||
VALUES ($1, $2, $3, $4, $5) ON CONFLICT (user_id)
|
||||
DO UPDATE SET device_id = excluded.device_id,
|
||||
access_token = excluded.access_token,
|
||||
homeserver_url = excluded.homeserver_url,
|
||||
next_batch = excluded.next_batch
|
||||
`
|
||||
)
|
||||
|
||||
type AccountQuery struct {
|
||||
*dbutil.QueryHelper[*Account]
|
||||
}
|
||||
|
||||
func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) {
|
||||
var exists bool
|
||||
if exists, err = aq.GetDB().TableExists(ctx, "account"); err != nil || !exists {
|
||||
return
|
||||
}
|
||||
err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (aq *AccountQuery) Get(ctx context.Context, userID id.UserID) (*Account, error) {
|
||||
return aq.QueryOne(ctx, getAccountQuery, userID)
|
||||
}
|
||||
|
||||
func (aq *AccountQuery) PutNextBatch(ctx context.Context, userID id.UserID, nextBatch string) error {
|
||||
return aq.Exec(ctx, putNextBatchQuery, nextBatch, userID)
|
||||
}
|
||||
|
||||
func (aq *AccountQuery) Put(ctx context.Context, account *Account) error {
|
||||
return aq.Exec(ctx, upsertAccountQuery, account.sqlVariables()...)
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
UserID id.UserID
|
||||
DeviceID id.DeviceID
|
||||
AccessToken string
|
||||
HomeserverURL string
|
||||
NextBatch string
|
||||
}
|
||||
|
||||
func (a *Account) Scan(row dbutil.Scannable) (*Account, error) {
|
||||
return dbutil.ValueOrErr(a, row.Scan(&a.UserID, &a.DeviceID, &a.AccessToken, &a.HomeserverURL, &a.NextBatch))
|
||||
}
|
||||
|
||||
func (a *Account) sqlVariables() []any {
|
||||
return []any{a.UserID, a.DeviceID, a.AccessToken, a.HomeserverURL, a.NextBatch}
|
||||
}
|
67
pkg/hicli/database/accountdata.go
Normal file
67
pkg/hicli/database/accountdata.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"unsafe"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
upsertAccountDataQuery = `
|
||||
INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content
|
||||
`
|
||||
upsertRoomAccountDataQuery = `
|
||||
INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content
|
||||
`
|
||||
)
|
||||
|
||||
type AccountDataQuery struct {
|
||||
*dbutil.QueryHelper[*AccountData]
|
||||
}
|
||||
|
||||
func unsafeJSONString(content json.RawMessage) *string {
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
str := unsafe.String(unsafe.SliceData(content), len(content))
|
||||
return &str
|
||||
}
|
||||
|
||||
func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error {
|
||||
return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content))
|
||||
}
|
||||
|
||||
func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error {
|
||||
return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content))
|
||||
}
|
||||
|
||||
type AccountData struct {
|
||||
UserID id.UserID `json:"user_id"`
|
||||
RoomID id.RoomID `json:"room_id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) {
|
||||
var roomID sql.NullString
|
||||
err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.RoomID = id.RoomID(roomID.String)
|
||||
return a, nil
|
||||
}
|
149
pkg/hicli/database/cachedmedia.go
Normal file
149
pkg/hicli/database/cachedmedia.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/jsontime"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto/attachment"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
insertCachedMediaQuery = `
|
||||
INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (mxc) DO NOTHING
|
||||
`
|
||||
upsertCachedMediaQuery = `
|
||||
INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (mxc) DO UPDATE
|
||||
SET enc_file = excluded.enc_file,
|
||||
file_name = excluded.file_name,
|
||||
mime_type = excluded.mime_type,
|
||||
size = excluded.size,
|
||||
hash = excluded.hash,
|
||||
error = excluded.error
|
||||
WHERE excluded.error IS NULL OR cached_media.hash IS NULL
|
||||
`
|
||||
getCachedMediaQuery = `
|
||||
SELECT mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error
|
||||
FROM cached_media
|
||||
WHERE mxc = $1
|
||||
`
|
||||
)
|
||||
|
||||
type CachedMediaQuery struct {
|
||||
*dbutil.QueryHelper[*CachedMedia]
|
||||
}
|
||||
|
||||
func (cmq *CachedMediaQuery) Add(ctx context.Context, cm *CachedMedia) error {
|
||||
return cmq.Exec(ctx, insertCachedMediaQuery, cm.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (cmq *CachedMediaQuery) Put(ctx context.Context, cm *CachedMedia) error {
|
||||
return cmq.Exec(ctx, upsertCachedMediaQuery, cm.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (cmq *CachedMediaQuery) Get(ctx context.Context, mxc id.ContentURI) (*CachedMedia, error) {
|
||||
return cmq.QueryOne(ctx, getCachedMediaQuery, &mxc)
|
||||
}
|
||||
|
||||
type MediaError struct {
|
||||
Matrix *mautrix.RespError `json:"data"`
|
||||
StatusCode int `json:"status_code"`
|
||||
ReceivedAt jsontime.UnixMilli `json:"received_at"`
|
||||
Attempts int `json:"attempts"`
|
||||
}
|
||||
|
||||
const MaxMediaBackoff = 7 * 24 * time.Hour
|
||||
|
||||
func (me *MediaError) backoff() time.Duration {
|
||||
return min(time.Duration(2<<me.Attempts)*time.Second, MaxMediaBackoff)
|
||||
}
|
||||
|
||||
func (me *MediaError) UseCache() bool {
|
||||
return me != nil && time.Since(me.ReceivedAt.Time) < me.backoff()
|
||||
}
|
||||
|
||||
func (me *MediaError) Write(w http.ResponseWriter) {
|
||||
if me.Matrix.ExtraData == nil {
|
||||
me.Matrix.ExtraData = make(map[string]any)
|
||||
}
|
||||
me.Matrix.ExtraData["fi.mau.hicli.error_ts"] = me.ReceivedAt.UnixMilli()
|
||||
me.Matrix.ExtraData["fi.mau.hicli.next_retry_ts"] = me.ReceivedAt.Add(me.backoff()).UnixMilli()
|
||||
me.Matrix.WithStatus(me.StatusCode).Write(w)
|
||||
}
|
||||
|
||||
type CachedMedia struct {
|
||||
MXC id.ContentURI
|
||||
EventRowID EventRowID
|
||||
EncFile *attachment.EncryptedFile
|
||||
FileName string
|
||||
MimeType string
|
||||
Size int64
|
||||
Hash *[32]byte
|
||||
Error *MediaError
|
||||
}
|
||||
|
||||
func (c *CachedMedia) UseCache() bool {
|
||||
return c != nil && (c.Hash != nil || c.Error.UseCache())
|
||||
}
|
||||
|
||||
func (c *CachedMedia) sqlVariables() []any {
|
||||
var hash []byte
|
||||
if c.Hash != nil {
|
||||
hash = c.Hash[:]
|
||||
}
|
||||
return []any{
|
||||
&c.MXC, dbutil.NumPtr(c.EventRowID), dbutil.JSONPtr(c.EncFile),
|
||||
dbutil.StrPtr(c.FileName), dbutil.StrPtr(c.MimeType), dbutil.NumPtr(c.Size),
|
||||
hash, dbutil.JSONPtr(c.Error),
|
||||
}
|
||||
}
|
||||
|
||||
var safeMimes = []string{
|
||||
"text/css", "text/plain", "text/csv",
|
||||
"application/json", "application/ld+json",
|
||||
"image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif",
|
||||
"video/mp4", "video/webm", "video/ogg", "video/quicktime",
|
||||
"audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave",
|
||||
"audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac",
|
||||
}
|
||||
|
||||
func (c *CachedMedia) Scan(row dbutil.Scannable) (*CachedMedia, error) {
|
||||
var mimeType, fileName sql.NullString
|
||||
var size, eventRowID sql.NullInt64
|
||||
var hash []byte
|
||||
err := row.Scan(&c.MXC, &eventRowID, dbutil.JSON{Data: &c.EncFile}, &fileName, &mimeType, &size, &hash, dbutil.JSON{Data: &c.Error})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.MimeType = mimeType.String
|
||||
c.FileName = fileName.String
|
||||
c.EventRowID = EventRowID(eventRowID.Int64)
|
||||
c.Size = size.Int64
|
||||
if len(hash) == 32 {
|
||||
c.Hash = (*[32]byte)(hash)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *CachedMedia) ContentDisposition() string {
|
||||
if slices.Contains(safeMimes, c.MimeType) {
|
||||
return "inline"
|
||||
}
|
||||
return "attachment"
|
||||
}
|
73
pkg/hicli/database/database.go
Normal file
73
pkg/hicli/database/database.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database/upgrades"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
*dbutil.Database
|
||||
|
||||
Account AccountQuery
|
||||
AccountData AccountDataQuery
|
||||
Room RoomQuery
|
||||
Event EventQuery
|
||||
CurrentState CurrentStateQuery
|
||||
Timeline TimelineQuery
|
||||
SessionRequest SessionRequestQuery
|
||||
Receipt ReceiptQuery
|
||||
CachedMedia CachedMediaQuery
|
||||
}
|
||||
|
||||
func New(rawDB *dbutil.Database) *Database {
|
||||
rawDB.UpgradeTable = upgrades.Table
|
||||
eventQH := dbutil.MakeQueryHelper(rawDB, newEvent)
|
||||
return &Database{
|
||||
Database: rawDB,
|
||||
|
||||
Account: AccountQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccount)},
|
||||
AccountData: AccountDataQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccountData)},
|
||||
Room: RoomQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newRoom)},
|
||||
Event: EventQuery{QueryHelper: eventQH},
|
||||
CurrentState: CurrentStateQuery{QueryHelper: eventQH},
|
||||
Timeline: TimelineQuery{QueryHelper: eventQH},
|
||||
SessionRequest: SessionRequestQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newSessionRequest)},
|
||||
Receipt: ReceiptQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newReceipt)},
|
||||
CachedMedia: CachedMediaQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newCachedMedia)},
|
||||
}
|
||||
}
|
||||
|
||||
func newSessionRequest(_ *dbutil.QueryHelper[*SessionRequest]) *SessionRequest {
|
||||
return &SessionRequest{}
|
||||
}
|
||||
|
||||
func newEvent(_ *dbutil.QueryHelper[*Event]) *Event {
|
||||
return &Event{}
|
||||
}
|
||||
|
||||
func newRoom(_ *dbutil.QueryHelper[*Room]) *Room {
|
||||
return &Room{}
|
||||
}
|
||||
|
||||
func newReceipt(_ *dbutil.QueryHelper[*Receipt]) *Receipt {
|
||||
return &Receipt{}
|
||||
}
|
||||
|
||||
func newCachedMedia(_ *dbutil.QueryHelper[*CachedMedia]) *CachedMedia {
|
||||
return &CachedMedia{}
|
||||
}
|
||||
|
||||
func newAccountData(_ *dbutil.QueryHelper[*AccountData]) *AccountData {
|
||||
return &AccountData{}
|
||||
}
|
||||
|
||||
func newAccount(_ *dbutil.QueryHelper[*Account]) *Account {
|
||||
return &Account{}
|
||||
}
|
509
pkg/hicli/database/event.go
Normal file
509
pkg/hicli/database/event.go
Normal file
|
@ -0,0 +1,509 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/exgjson"
|
||||
"go.mau.fi/util/jsontime"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
getEventBaseQuery = `
|
||||
SELECT rowid, -1,
|
||||
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
|
||||
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
|
||||
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
|
||||
FROM event
|
||||
`
|
||||
getEventByRowID = getEventBaseQuery + `WHERE rowid = $1`
|
||||
getManyEventsByRowID = getEventBaseQuery + `WHERE rowid IN (%s)`
|
||||
getEventByID = getEventBaseQuery + `WHERE event_id = $1`
|
||||
getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL`
|
||||
insertEventBaseQuery = `
|
||||
INSERT INTO event (
|
||||
room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
|
||||
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
|
||||
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
|
||||
`
|
||||
insertEventQuery = insertEventBaseQuery + `RETURNING rowid`
|
||||
upsertEventQuery = insertEventBaseQuery + `
|
||||
ON CONFLICT (event_id) DO UPDATE
|
||||
SET decrypted=COALESCE(event.decrypted, excluded.decrypted),
|
||||
decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type),
|
||||
redacted_by=COALESCE(event.redacted_by, excluded.redacted_by),
|
||||
decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END,
|
||||
send_error=excluded.send_error,
|
||||
timestamp=excluded.timestamp,
|
||||
unsigned=COALESCE(excluded.unsigned, event.unsigned),
|
||||
local_content=COALESCE(excluded.local_content, event.local_content)
|
||||
ON CONFLICT (transaction_id) DO UPDATE
|
||||
SET event_id=excluded.event_id,
|
||||
timestamp=excluded.timestamp,
|
||||
unsigned=excluded.unsigned
|
||||
RETURNING rowid
|
||||
`
|
||||
updateEventSendErrorQuery = `UPDATE event SET send_error = $2 WHERE rowid = $1`
|
||||
updateEventIDQuery = `UPDATE event SET event_id = $2, send_error = NULL WHERE rowid=$1`
|
||||
updateEventDecryptedQuery = `UPDATE event SET decrypted = $2, decrypted_type = $3, decryption_error = NULL, unread_type = $4, local_content = $5 WHERE rowid = $1`
|
||||
updateEventLocalContentQuery = `UPDATE event SET local_content = $2 WHERE rowid = $1`
|
||||
getEventReactionsQuery = getEventBaseQuery + `
|
||||
WHERE room_id = ?
|
||||
AND type = 'm.reaction'
|
||||
AND relation_type = 'm.annotation'
|
||||
AND redacted_by IS NULL
|
||||
AND relates_to IN (%s)
|
||||
`
|
||||
getEventEditRowIDsQuery = `
|
||||
SELECT main.event_id, edit.rowid
|
||||
FROM event main
|
||||
JOIN event edit ON
|
||||
edit.room_id = main.room_id
|
||||
AND edit.relates_to = main.event_id
|
||||
AND edit.relation_type = 'm.replace'
|
||||
AND edit.type = main.type
|
||||
AND edit.sender = main.sender
|
||||
AND edit.redacted_by IS NULL
|
||||
WHERE main.event_id IN (%s)
|
||||
ORDER BY main.event_id, edit.timestamp
|
||||
`
|
||||
setLastEditRowIDQuery = `
|
||||
UPDATE event SET last_edit_rowid = $2 WHERE event_id = $1
|
||||
`
|
||||
updateReactionCountsQuery = `UPDATE event SET reactions = $2 WHERE event_id = $1`
|
||||
)
|
||||
|
||||
type EventQuery struct {
|
||||
*dbutil.QueryHelper[*Event]
|
||||
}
|
||||
|
||||
func (eq *EventQuery) GetFailedByMegolmSessionID(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) ([]*Event, error) {
|
||||
return eq.QueryMany(ctx, getFailedEventsByMegolmSessionID, roomID, sessionID)
|
||||
}
|
||||
|
||||
func (eq *EventQuery) GetByID(ctx context.Context, eventID id.EventID) (*Event, error) {
|
||||
return eq.QueryOne(ctx, getEventByID, eventID)
|
||||
}
|
||||
|
||||
func (eq *EventQuery) GetByRowID(ctx context.Context, rowID EventRowID) (*Event, error) {
|
||||
return eq.QueryOne(ctx, getEventByRowID, rowID)
|
||||
}
|
||||
|
||||
func (eq *EventQuery) GetByRowIDs(ctx context.Context, rowIDs ...EventRowID) ([]*Event, error) {
|
||||
query, params := buildMultiEventGetFunction(nil, rowIDs, getManyEventsByRowID)
|
||||
return eq.QueryMany(ctx, query, params...)
|
||||
}
|
||||
|
||||
func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID EventRowID, err error) {
|
||||
err = eq.GetDB().QueryRow(ctx, upsertEventQuery, evt.sqlVariables()...).Scan(&rowID)
|
||||
if err == nil {
|
||||
evt.RowID = rowID
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (eq *EventQuery) Insert(ctx context.Context, evt *Event) (rowID EventRowID, err error) {
|
||||
err = eq.GetDB().QueryRow(ctx, insertEventQuery, evt.sqlVariables()...).Scan(&rowID)
|
||||
if err == nil {
|
||||
evt.RowID = rowID
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (eq *EventQuery) UpdateID(ctx context.Context, rowID EventRowID, newID id.EventID) error {
|
||||
return eq.Exec(ctx, updateEventIDQuery, rowID, newID)
|
||||
}
|
||||
|
||||
func (eq *EventQuery) UpdateSendError(ctx context.Context, rowID EventRowID, sendError string) error {
|
||||
return eq.Exec(ctx, updateEventSendErrorQuery, rowID, sendError)
|
||||
}
|
||||
|
||||
func (eq *EventQuery) UpdateDecrypted(ctx context.Context, evt *Event) error {
|
||||
return eq.Exec(
|
||||
ctx,
|
||||
updateEventDecryptedQuery,
|
||||
evt.RowID,
|
||||
unsafeJSONString(evt.Decrypted),
|
||||
evt.DecryptedType,
|
||||
evt.UnreadType,
|
||||
dbutil.JSONPtr(evt.LocalContent),
|
||||
)
|
||||
}
|
||||
|
||||
func (eq *EventQuery) UpdateLocalContent(ctx context.Context, evt *Event) error {
|
||||
return eq.Exec(ctx, updateEventLocalContentQuery, evt.RowID, dbutil.JSONPtr(evt.LocalContent))
|
||||
}
|
||||
|
||||
func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error {
|
||||
eventIDs := make([]id.EventID, 0)
|
||||
eventMap := make(map[id.EventID]*Event)
|
||||
for i, evt := range events {
|
||||
if evt.Reactions == nil {
|
||||
eventIDs[i] = evt.ID
|
||||
eventMap[evt.ID] = evt
|
||||
}
|
||||
}
|
||||
result, err := eq.GetReactions(ctx, roomID, eventIDs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for evtID, res := range result {
|
||||
eventMap[evtID].Reactions = res.Counts
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (eq *EventQuery) FillLastEditRowIDs(ctx context.Context, roomID id.RoomID, events []*Event) error {
|
||||
eventIDs := make([]id.EventID, len(events))
|
||||
eventMap := make(map[id.EventID]*Event)
|
||||
for i, evt := range events {
|
||||
if evt.LastEditRowID == nil {
|
||||
eventIDs[i] = evt.ID
|
||||
eventMap[evt.ID] = evt
|
||||
}
|
||||
}
|
||||
return eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
result, err := eq.GetEditRowIDs(ctx, roomID, eventIDs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for evtID, res := range result {
|
||||
lastEditRowID := res[len(res)-1]
|
||||
eventMap[evtID].LastEditRowID = &lastEditRowID
|
||||
delete(eventMap, evtID)
|
||||
err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, lastEditRowID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
var zero EventRowID
|
||||
for evtID, evt := range eventMap {
|
||||
evt.LastEditRowID = &zero
|
||||
err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, zero)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var reactionKeyPath = exgjson.Path("m.relates_to", "key")
|
||||
|
||||
type GetReactionsResult struct {
|
||||
Events []*Event
|
||||
Counts map[string]int
|
||||
}
|
||||
|
||||
func buildMultiEventGetFunction[T any](preParams []any, eventIDs []T, query string) (string, []any) {
|
||||
params := make([]any, len(preParams)+len(eventIDs))
|
||||
copy(params, preParams)
|
||||
for i, evtID := range eventIDs {
|
||||
params[i+len(preParams)] = evtID
|
||||
}
|
||||
placeholders := strings.Repeat("?,", len(eventIDs))
|
||||
placeholders = placeholders[:len(placeholders)-1]
|
||||
return fmt.Sprintf(query, placeholders), params
|
||||
}
|
||||
|
||||
type editRowIDTuple struct {
|
||||
eventID id.EventID
|
||||
editRowID EventRowID
|
||||
}
|
||||
|
||||
func (eq *EventQuery) GetEditRowIDs(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID][]EventRowID, error) {
|
||||
query, params := buildMultiEventGetFunction([]any{roomID}, eventIDs, getEventEditRowIDsQuery)
|
||||
rows, err := eq.GetDB().Query(ctx, query, params...)
|
||||
output := make(map[id.EventID][]EventRowID)
|
||||
return output, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (tuple editRowIDTuple, err error) {
|
||||
err = row.Scan(&tuple.eventID, &tuple.editRowID)
|
||||
return
|
||||
}, err).Iter(func(tuple editRowIDTuple) (bool, error) {
|
||||
output[tuple.eventID] = append(output[tuple.eventID], tuple.editRowID)
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (eq *EventQuery) GetReactions(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID]*GetReactionsResult, error) {
|
||||
result := make(map[id.EventID]*GetReactionsResult, len(eventIDs))
|
||||
for _, evtID := range eventIDs {
|
||||
result[evtID] = &GetReactionsResult{Counts: make(map[string]int)}
|
||||
}
|
||||
return result, eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
query, params := buildMultiEventGetFunction([]any{roomID}, eventIDs, getEventReactionsQuery)
|
||||
events, err := eq.QueryMany(ctx, query, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(events) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, evt := range events {
|
||||
dest := result[evt.RelatesTo]
|
||||
dest.Events = append(dest.Events, evt)
|
||||
keyRes := gjson.GetBytes(evt.Content, reactionKeyPath)
|
||||
if keyRes.Type == gjson.String {
|
||||
dest.Counts[keyRes.Str]++
|
||||
}
|
||||
}
|
||||
for evtID, res := range result {
|
||||
if len(res.Counts) > 0 {
|
||||
err = eq.Exec(ctx, updateReactionCountsQuery, evtID, dbutil.JSON{Data: &res.Counts})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
type EventRowID int64
|
||||
|
||||
func (m EventRowID) GetMassInsertValues() [1]any {
|
||||
return [1]any{m}
|
||||
}
|
||||
|
||||
type LocalContent struct {
|
||||
SanitizedHTML string `json:"sanitized_html,omitempty"`
|
||||
HTMLVersion int `json:"html_version,omitempty"`
|
||||
}
|
||||
|
||||
type UnreadType int
|
||||
|
||||
func (ut UnreadType) Is(flag UnreadType) bool {
|
||||
return ut&flag != 0
|
||||
}
|
||||
|
||||
const (
|
||||
UnreadTypeNone UnreadType = 0b0000
|
||||
UnreadTypeNormal UnreadType = 0b0001
|
||||
UnreadTypeNotify UnreadType = 0b0010
|
||||
UnreadTypeHighlight UnreadType = 0b0100
|
||||
UnreadTypeSound UnreadType = 0b1000
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
RowID EventRowID `json:"rowid"`
|
||||
TimelineRowID TimelineRowID `json:"timeline_rowid"`
|
||||
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
ID id.EventID `json:"event_id"`
|
||||
Sender id.UserID `json:"sender"`
|
||||
Type string `json:"type"`
|
||||
StateKey *string `json:"state_key,omitempty"`
|
||||
Timestamp jsontime.UnixMilli `json:"timestamp"`
|
||||
|
||||
Content json.RawMessage `json:"content"`
|
||||
Decrypted json.RawMessage `json:"decrypted,omitempty"`
|
||||
DecryptedType string `json:"decrypted_type,omitempty"`
|
||||
Unsigned json.RawMessage `json:"unsigned,omitempty"`
|
||||
LocalContent *LocalContent `json:"local_content,omitempty"`
|
||||
|
||||
TransactionID string `json:"transaction_id,omitempty"`
|
||||
|
||||
RedactedBy id.EventID `json:"redacted_by,omitempty"`
|
||||
RelatesTo id.EventID `json:"relates_to,omitempty"`
|
||||
RelationType event.RelationType `json:"relation_type,omitempty"`
|
||||
|
||||
MegolmSessionID id.SessionID `json:"-,omitempty"`
|
||||
DecryptionError string `json:"decryption_error,omitempty"`
|
||||
SendError string `json:"send_error,omitempty"`
|
||||
|
||||
Reactions map[string]int `json:"reactions,omitempty"`
|
||||
LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"`
|
||||
UnreadType UnreadType `json:"unread_type,omitempty"`
|
||||
}
|
||||
|
||||
func MautrixToEvent(evt *event.Event) *Event {
|
||||
dbEvt := &Event{
|
||||
RoomID: evt.RoomID,
|
||||
ID: evt.ID,
|
||||
Sender: evt.Sender,
|
||||
Type: evt.Type.Type,
|
||||
StateKey: evt.StateKey,
|
||||
Timestamp: jsontime.UM(time.UnixMilli(evt.Timestamp)),
|
||||
Content: evt.Content.VeryRaw,
|
||||
MegolmSessionID: getMegolmSessionID(evt),
|
||||
TransactionID: evt.Unsigned.TransactionID,
|
||||
}
|
||||
if !strings.HasPrefix(dbEvt.TransactionID, "hicli-mautrix-go_") {
|
||||
dbEvt.TransactionID = ""
|
||||
}
|
||||
dbEvt.RelatesTo, dbEvt.RelationType = getRelatesToFromEvent(evt)
|
||||
dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned)
|
||||
if evt.Unsigned.RedactedBecause != nil {
|
||||
dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID
|
||||
}
|
||||
return dbEvt
|
||||
}
|
||||
|
||||
func (e *Event) AsRawMautrix() *event.Event {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
evt := &event.Event{
|
||||
RoomID: e.RoomID,
|
||||
ID: e.ID,
|
||||
Sender: e.Sender,
|
||||
Type: event.Type{Type: e.Type, Class: event.MessageEventType},
|
||||
StateKey: e.StateKey,
|
||||
Timestamp: e.Timestamp.UnixMilli(),
|
||||
Content: event.Content{VeryRaw: e.Content},
|
||||
}
|
||||
if e.Decrypted != nil {
|
||||
evt.Content.VeryRaw = e.Decrypted
|
||||
evt.Type.Type = e.DecryptedType
|
||||
evt.Mautrix.WasEncrypted = true
|
||||
}
|
||||
if e.StateKey != nil {
|
||||
evt.Type.Class = event.StateEventType
|
||||
}
|
||||
_ = json.Unmarshal(e.Unsigned, &evt.Unsigned)
|
||||
return evt
|
||||
}
|
||||
|
||||
func (e *Event) Scan(row dbutil.Scannable) (*Event, error) {
|
||||
var timestamp int64
|
||||
var transactionID, redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, sendError, decryptedType sql.NullString
|
||||
err := row.Scan(
|
||||
&e.RowID,
|
||||
&e.TimelineRowID,
|
||||
&e.RoomID,
|
||||
&e.ID,
|
||||
&e.Sender,
|
||||
&e.Type,
|
||||
&e.StateKey,
|
||||
×tamp,
|
||||
(*[]byte)(&e.Content),
|
||||
(*[]byte)(&e.Decrypted),
|
||||
&decryptedType,
|
||||
(*[]byte)(&e.Unsigned),
|
||||
dbutil.JSON{Data: &e.LocalContent},
|
||||
&transactionID,
|
||||
&redactedBy,
|
||||
&relatesTo,
|
||||
&relationType,
|
||||
&megolmSessionID,
|
||||
&decryptionError,
|
||||
&sendError,
|
||||
dbutil.JSON{Data: &e.Reactions},
|
||||
&e.LastEditRowID,
|
||||
&e.UnreadType,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.Timestamp = jsontime.UM(time.UnixMilli(timestamp))
|
||||
e.TransactionID = transactionID.String
|
||||
e.RedactedBy = id.EventID(redactedBy.String)
|
||||
e.RelatesTo = id.EventID(relatesTo.String)
|
||||
e.RelationType = event.RelationType(relationType.String)
|
||||
e.MegolmSessionID = id.SessionID(megolmSessionID.String)
|
||||
e.DecryptedType = decryptedType.String
|
||||
e.DecryptionError = decryptionError.String
|
||||
e.SendError = sendError.String
|
||||
return e, nil
|
||||
}
|
||||
|
||||
var relatesToPath = exgjson.Path("m.relates_to", "event_id")
|
||||
var relationTypePath = exgjson.Path("m.relates_to", "rel_type")
|
||||
|
||||
func getRelatesToFromEvent(evt *event.Event) (id.EventID, event.RelationType) {
|
||||
if evt.StateKey != nil {
|
||||
return "", ""
|
||||
}
|
||||
return GetRelatesToFromBytes(evt.Content.VeryRaw)
|
||||
}
|
||||
|
||||
func GetRelatesToFromBytes(content []byte) (id.EventID, event.RelationType) {
|
||||
results := gjson.GetManyBytes(content, relatesToPath, relationTypePath)
|
||||
if len(results) == 2 && results[0].Exists() && results[1].Exists() && results[0].Type == gjson.String && results[1].Type == gjson.String {
|
||||
return id.EventID(results[0].Str), event.RelationType(results[1].Str)
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func getMegolmSessionID(evt *event.Event) id.SessionID {
|
||||
if evt.Type != event.EventEncrypted {
|
||||
return ""
|
||||
}
|
||||
res := gjson.GetBytes(evt.Content.VeryRaw, "session_id")
|
||||
if res.Exists() && res.Type == gjson.String {
|
||||
return id.SessionID(res.Str)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *Event) sqlVariables() []any {
|
||||
var reactions any
|
||||
if e.Reactions != nil {
|
||||
reactions = e.Reactions
|
||||
}
|
||||
return []any{
|
||||
e.RoomID,
|
||||
e.ID,
|
||||
e.Sender,
|
||||
e.Type,
|
||||
e.StateKey,
|
||||
e.Timestamp.UnixMilli(),
|
||||
unsafeJSONString(e.Content),
|
||||
unsafeJSONString(e.Decrypted),
|
||||
dbutil.StrPtr(e.DecryptedType),
|
||||
unsafeJSONString(e.Unsigned),
|
||||
dbutil.JSONPtr(e.LocalContent),
|
||||
dbutil.StrPtr(e.TransactionID),
|
||||
dbutil.StrPtr(e.RedactedBy),
|
||||
dbutil.StrPtr(e.RelatesTo),
|
||||
dbutil.StrPtr(e.RelationType),
|
||||
dbutil.StrPtr(e.MegolmSessionID),
|
||||
dbutil.StrPtr(e.DecryptionError),
|
||||
dbutil.StrPtr(e.SendError),
|
||||
dbutil.JSON{Data: reactions},
|
||||
e.LastEditRowID,
|
||||
e.UnreadType,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Event) GetNonPushUnreadType() UnreadType {
|
||||
if e.RelationType == event.RelReplace {
|
||||
return UnreadTypeNone
|
||||
}
|
||||
switch e.Type {
|
||||
case event.EventMessage.Type, event.EventSticker.Type:
|
||||
return UnreadTypeNormal
|
||||
case event.EventEncrypted.Type:
|
||||
switch e.DecryptedType {
|
||||
case event.EventMessage.Type, event.EventSticker.Type:
|
||||
return UnreadTypeNormal
|
||||
}
|
||||
}
|
||||
return UnreadTypeNone
|
||||
}
|
||||
|
||||
func (e *Event) CanUseForPreview() bool {
|
||||
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type ||
|
||||
(e.Type == event.EventEncrypted.Type &&
|
||||
(e.DecryptedType == event.EventMessage.Type || e.DecryptedType == event.EventSticker.Type))) &&
|
||||
e.RelationType != event.RelReplace && e.RedactedBy == ""
|
||||
}
|
||||
|
||||
func (e *Event) BumpsSortingTimestamp() bool {
|
||||
return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) &&
|
||||
e.RelationType != event.RelReplace
|
||||
}
|
81
pkg/hicli/database/receipt.go
Normal file
81
pkg/hicli/database/receipt.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/jsontime"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
upsertReceiptQuery = `
|
||||
INSERT INTO receipt (room_id, user_id, receipt_type, thread_id, event_id, timestamp)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (room_id, user_id, receipt_type, thread_id) DO UPDATE
|
||||
SET event_id = excluded.event_id,
|
||||
timestamp = excluded.timestamp
|
||||
`
|
||||
)
|
||||
|
||||
var receiptMassInserter = dbutil.NewMassInsertBuilder[*Receipt, [1]any](upsertReceiptQuery, "($1, $%d, $%d, $%d, $%d, $%d)")
|
||||
|
||||
type ReceiptQuery struct {
|
||||
*dbutil.QueryHelper[*Receipt]
|
||||
}
|
||||
|
||||
func (rq *ReceiptQuery) Put(ctx context.Context, receipt *Receipt) error {
|
||||
return rq.Exec(ctx, upsertReceiptQuery, receipt.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (rq *ReceiptQuery) PutMany(ctx context.Context, roomID id.RoomID, receipts ...*Receipt) error {
|
||||
if len(receipts) > 1000 {
|
||||
return rq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
for receiptChunk := range slices.Chunk(receipts, 200) {
|
||||
err := rq.PutMany(ctx, roomID, receiptChunk...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
query, params := receiptMassInserter.Build([1]any{roomID}, receipts)
|
||||
return rq.Exec(ctx, query, params...)
|
||||
}
|
||||
|
||||
type Receipt struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
UserID id.UserID `json:"user_id"`
|
||||
ReceiptType event.ReceiptType `json:"receipt_type"`
|
||||
ThreadID event.ThreadID `json:"thread_id"`
|
||||
EventID id.EventID `json:"event_id"`
|
||||
Timestamp jsontime.UnixMilli `json:"timestamp"`
|
||||
}
|
||||
|
||||
func (r *Receipt) Scan(row dbutil.Scannable) (*Receipt, error) {
|
||||
var ts int64
|
||||
err := row.Scan(&r.RoomID, &r.UserID, &r.ReceiptType, &r.ThreadID, &r.EventID, &ts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Timestamp = jsontime.UM(time.UnixMilli(ts))
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Receipt) sqlVariables() []any {
|
||||
return []any{r.RoomID, r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()}
|
||||
}
|
||||
|
||||
func (r *Receipt) GetMassInsertValues() [5]any {
|
||||
return [5]any{r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()}
|
||||
}
|
278
pkg/hicli/database/room.go
Normal file
278
pkg/hicli/database/room.go
Normal file
|
@ -0,0 +1,278 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/jsontime"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
getRoomBaseQuery = `
|
||||
SELECT room_id, creation_content, name, name_quality, avatar, explicit_avatar, topic, canonical_alias,
|
||||
lazy_load_summary, encryption_event, has_member_list, preview_event_rowid, sorting_timestamp,
|
||||
unread_highlights, unread_notifications, unread_messages, prev_batch
|
||||
FROM room
|
||||
`
|
||||
getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 AND sorting_timestamp > 0 ORDER BY sorting_timestamp DESC LIMIT $2`
|
||||
getRoomByIDQuery = getRoomBaseQuery + `WHERE room_id = $1`
|
||||
ensureRoomExistsQuery = `
|
||||
INSERT INTO room (room_id) VALUES ($1)
|
||||
ON CONFLICT (room_id) DO NOTHING
|
||||
`
|
||||
upsertRoomFromSyncQuery = `
|
||||
UPDATE room
|
||||
SET creation_content = COALESCE(room.creation_content, $2),
|
||||
name = COALESCE($3, room.name),
|
||||
name_quality = CASE WHEN $3 IS NOT NULL THEN $4 ELSE room.name_quality END,
|
||||
avatar = COALESCE($5, room.avatar),
|
||||
explicit_avatar = CASE WHEN $5 IS NOT NULL THEN $6 ELSE room.explicit_avatar END,
|
||||
topic = COALESCE($7, room.topic),
|
||||
canonical_alias = COALESCE($8, room.canonical_alias),
|
||||
lazy_load_summary = COALESCE($9, room.lazy_load_summary),
|
||||
encryption_event = COALESCE($10, room.encryption_event),
|
||||
has_member_list = room.has_member_list OR $11,
|
||||
preview_event_rowid = COALESCE($12, room.preview_event_rowid),
|
||||
sorting_timestamp = COALESCE($13, room.sorting_timestamp),
|
||||
unread_highlights = COALESCE($14, room.unread_highlights),
|
||||
unread_notifications = COALESCE($15, room.unread_notifications),
|
||||
unread_messages = COALESCE($16, room.unread_messages),
|
||||
prev_batch = COALESCE($17, room.prev_batch)
|
||||
WHERE room_id = $1
|
||||
`
|
||||
setRoomPrevBatchQuery = `
|
||||
UPDATE room SET prev_batch = $2 WHERE room_id = $1
|
||||
`
|
||||
updateRoomPreviewIfLaterOnTimelineQuery = `
|
||||
UPDATE room
|
||||
SET preview_event_rowid = $2
|
||||
WHERE room_id = $1
|
||||
AND COALESCE((SELECT rowid FROM timeline WHERE event_rowid = $2), -1)
|
||||
> COALESCE((SELECT rowid FROM timeline WHERE event_rowid = preview_event_rowid), 0)
|
||||
RETURNING preview_event_rowid
|
||||
`
|
||||
recalculateRoomPreviewEventQuery = `
|
||||
SELECT rowid
|
||||
FROM event
|
||||
WHERE
|
||||
room_id = $1
|
||||
AND (type IN ('m.room.message', 'm.sticker')
|
||||
OR (type = 'm.room.encrypted'
|
||||
AND decrypted_type IN ('m.room.message', 'm.sticker')))
|
||||
AND relation_type <> 'm.replace'
|
||||
AND redacted_by IS NULL
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1
|
||||
`
|
||||
)
|
||||
|
||||
type RoomQuery struct {
|
||||
*dbutil.QueryHelper[*Room]
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) {
|
||||
return rq.QueryOne(ctx, getRoomByIDQuery, roomID)
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) GetBySortTS(ctx context.Context, maxTS time.Time, limit int) ([]*Room, error) {
|
||||
return rq.QueryMany(ctx, getRoomsBySortingTimestampQuery, maxTS.UnixMilli(), limit)
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error {
|
||||
return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...)
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) CreateRow(ctx context.Context, roomID id.RoomID) error {
|
||||
return rq.Exec(ctx, ensureRoomExistsQuery, roomID)
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBatch string) error {
|
||||
return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch)
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID id.RoomID, rowID EventRowID) (previewChanged bool, err error) {
|
||||
var newPreviewRowID EventRowID
|
||||
err = rq.GetDB().QueryRow(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID).Scan(&newPreviewRowID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
} else if err == nil {
|
||||
previewChanged = newPreviewRowID == rowID
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (rq *RoomQuery) RecalculatePreview(ctx context.Context, roomID id.RoomID) (rowID EventRowID, err error) {
|
||||
err = rq.GetDB().QueryRow(ctx, recalculateRoomPreviewEventQuery, roomID).Scan(&rowID)
|
||||
return
|
||||
}
|
||||
|
||||
type NameQuality int
|
||||
|
||||
const (
|
||||
NameQualityNil NameQuality = iota
|
||||
NameQualityParticipants
|
||||
NameQualityCanonicalAlias
|
||||
NameQualityExplicit
|
||||
)
|
||||
|
||||
const PrevBatchPaginationComplete = "fi.mau.gomuks.pagination_complete"
|
||||
|
||||
type Room struct {
|
||||
ID id.RoomID `json:"room_id"`
|
||||
CreationContent *event.CreateEventContent `json:"creation_content,omitempty"`
|
||||
|
||||
Name *string `json:"name,omitempty"`
|
||||
NameQuality NameQuality `json:"name_quality"`
|
||||
Avatar *id.ContentURI `json:"avatar,omitempty"`
|
||||
ExplicitAvatar bool `json:"explicit_avatar"`
|
||||
Topic *string `json:"topic,omitempty"`
|
||||
CanonicalAlias *id.RoomAlias `json:"canonical_alias,omitempty"`
|
||||
|
||||
LazyLoadSummary *mautrix.LazyLoadSummary `json:"lazy_load_summary,omitempty"`
|
||||
|
||||
EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"`
|
||||
HasMemberList bool `json:"has_member_list"`
|
||||
|
||||
PreviewEventRowID EventRowID `json:"preview_event_rowid"`
|
||||
SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"`
|
||||
UnreadHighlights int `json:"unread_highlights"`
|
||||
UnreadNotifications int `json:"unread_notifications"`
|
||||
UnreadMessages int `json:"unread_messages"`
|
||||
|
||||
PrevBatch string `json:"prev_batch"`
|
||||
}
|
||||
|
||||
func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) {
|
||||
if r.Name != nil && r.NameQuality >= other.NameQuality {
|
||||
other.Name = r.Name
|
||||
other.NameQuality = r.NameQuality
|
||||
hasChanges = true
|
||||
}
|
||||
if r.Avatar != nil {
|
||||
other.Avatar = r.Avatar
|
||||
other.ExplicitAvatar = r.ExplicitAvatar
|
||||
hasChanges = true
|
||||
}
|
||||
if r.Topic != nil {
|
||||
other.Topic = r.Topic
|
||||
hasChanges = true
|
||||
}
|
||||
if r.CanonicalAlias != nil {
|
||||
other.CanonicalAlias = r.CanonicalAlias
|
||||
hasChanges = true
|
||||
}
|
||||
if r.LazyLoadSummary != nil {
|
||||
other.LazyLoadSummary = r.LazyLoadSummary
|
||||
hasChanges = true
|
||||
}
|
||||
if r.EncryptionEvent != nil && other.EncryptionEvent == nil {
|
||||
other.EncryptionEvent = r.EncryptionEvent
|
||||
hasChanges = true
|
||||
}
|
||||
if r.HasMemberList && !other.HasMemberList {
|
||||
hasChanges = true
|
||||
other.HasMemberList = true
|
||||
}
|
||||
if r.PreviewEventRowID > other.PreviewEventRowID {
|
||||
other.PreviewEventRowID = r.PreviewEventRowID
|
||||
hasChanges = true
|
||||
}
|
||||
if r.SortingTimestamp.After(other.SortingTimestamp.Time) {
|
||||
other.SortingTimestamp = r.SortingTimestamp
|
||||
hasChanges = true
|
||||
}
|
||||
if r.UnreadHighlights != other.UnreadHighlights {
|
||||
other.UnreadHighlights = r.UnreadHighlights
|
||||
hasChanges = true
|
||||
}
|
||||
if r.UnreadNotifications != other.UnreadNotifications {
|
||||
other.UnreadNotifications = r.UnreadNotifications
|
||||
hasChanges = true
|
||||
}
|
||||
if r.UnreadMessages != other.UnreadMessages {
|
||||
other.UnreadMessages = r.UnreadMessages
|
||||
hasChanges = true
|
||||
}
|
||||
if r.PrevBatch != "" && other.PrevBatch == "" {
|
||||
other.PrevBatch = r.PrevBatch
|
||||
hasChanges = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Room) Scan(row dbutil.Scannable) (*Room, error) {
|
||||
var prevBatch sql.NullString
|
||||
var previewEventRowID, sortingTimestamp sql.NullInt64
|
||||
err := row.Scan(
|
||||
&r.ID,
|
||||
dbutil.JSON{Data: &r.CreationContent},
|
||||
&r.Name,
|
||||
&r.NameQuality,
|
||||
&r.Avatar,
|
||||
&r.ExplicitAvatar,
|
||||
&r.Topic,
|
||||
&r.CanonicalAlias,
|
||||
dbutil.JSON{Data: &r.LazyLoadSummary},
|
||||
dbutil.JSON{Data: &r.EncryptionEvent},
|
||||
&r.HasMemberList,
|
||||
&previewEventRowID,
|
||||
&sortingTimestamp,
|
||||
&r.UnreadHighlights,
|
||||
&r.UnreadNotifications,
|
||||
&r.UnreadMessages,
|
||||
&prevBatch,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.PrevBatch = prevBatch.String
|
||||
r.PreviewEventRowID = EventRowID(previewEventRowID.Int64)
|
||||
r.SortingTimestamp = jsontime.UM(time.UnixMilli(sortingTimestamp.Int64))
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Room) sqlVariables() []any {
|
||||
return []any{
|
||||
r.ID,
|
||||
dbutil.JSONPtr(r.CreationContent),
|
||||
r.Name,
|
||||
r.NameQuality,
|
||||
r.Avatar,
|
||||
r.ExplicitAvatar,
|
||||
r.Topic,
|
||||
r.CanonicalAlias,
|
||||
dbutil.JSONPtr(r.LazyLoadSummary),
|
||||
dbutil.JSONPtr(r.EncryptionEvent),
|
||||
r.HasMemberList,
|
||||
dbutil.NumPtr(r.PreviewEventRowID),
|
||||
dbutil.UnixMilliPtr(r.SortingTimestamp.Time),
|
||||
r.UnreadHighlights,
|
||||
r.UnreadNotifications,
|
||||
r.UnreadMessages,
|
||||
dbutil.StrPtr(r.PrevBatch),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) BumpSortingTimestamp(evt *Event) bool {
|
||||
if !evt.BumpsSortingTimestamp() || evt.Timestamp.Before(r.SortingTimestamp.Time) {
|
||||
return false
|
||||
}
|
||||
r.SortingTimestamp = evt.Timestamp
|
||||
now := time.Now()
|
||||
if r.SortingTimestamp.After(now) {
|
||||
r.SortingTimestamp = jsontime.UM(now)
|
||||
}
|
||||
return true
|
||||
}
|
68
pkg/hicli/database/sessionrequest.go
Normal file
68
pkg/hicli/database/sessionrequest.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
putSessionRequestQueueEntry = `
|
||||
INSERT INTO session_request (room_id, session_id, sender, min_index, backup_checked, request_sent)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (session_id) DO UPDATE
|
||||
SET min_index = MIN(excluded.min_index, session_request.min_index),
|
||||
backup_checked = excluded.backup_checked OR session_request.backup_checked,
|
||||
request_sent = excluded.request_sent OR session_request.request_sent
|
||||
`
|
||||
removeSessionRequestQuery = `
|
||||
DELETE FROM session_request WHERE session_id = $1 AND min_index >= $2
|
||||
`
|
||||
getNextSessionsToRequestQuery = `
|
||||
SELECT room_id, session_id, sender, min_index, backup_checked, request_sent
|
||||
FROM session_request
|
||||
WHERE request_sent = false OR backup_checked = false
|
||||
ORDER BY backup_checked, rowid
|
||||
LIMIT $1
|
||||
`
|
||||
)
|
||||
|
||||
type SessionRequestQuery struct {
|
||||
*dbutil.QueryHelper[*SessionRequest]
|
||||
}
|
||||
|
||||
func (srq *SessionRequestQuery) Next(ctx context.Context, count int) ([]*SessionRequest, error) {
|
||||
return srq.QueryMany(ctx, getNextSessionsToRequestQuery, count)
|
||||
}
|
||||
|
||||
func (srq *SessionRequestQuery) Remove(ctx context.Context, sessionID id.SessionID, minIndex uint32) error {
|
||||
return srq.Exec(ctx, removeSessionRequestQuery, sessionID, minIndex)
|
||||
}
|
||||
|
||||
func (srq *SessionRequestQuery) Put(ctx context.Context, sr *SessionRequest) error {
|
||||
return srq.Exec(ctx, putSessionRequestQueueEntry, sr.sqlVariables()...)
|
||||
}
|
||||
|
||||
type SessionRequest struct {
|
||||
RoomID id.RoomID
|
||||
SessionID id.SessionID
|
||||
Sender id.UserID
|
||||
MinIndex uint32
|
||||
BackupChecked bool
|
||||
RequestSent bool
|
||||
}
|
||||
|
||||
func (s *SessionRequest) Scan(row dbutil.Scannable) (*SessionRequest, error) {
|
||||
return dbutil.ValueOrErr(s, row.Scan(&s.RoomID, &s.SessionID, &s.Sender, &s.MinIndex, &s.BackupChecked, &s.RequestSent))
|
||||
}
|
||||
|
||||
func (s *SessionRequest) sqlVariables() []any {
|
||||
return []any{s.RoomID, s.SessionID, s.Sender, s.MinIndex, s.BackupChecked, s.RequestSent}
|
||||
}
|
94
pkg/hicli/database/state.go
Normal file
94
pkg/hicli/database/state.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
setCurrentStateQuery = `
|
||||
INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership
|
||||
`
|
||||
addCurrentStateQuery = `
|
||||
INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT DO NOTHING
|
||||
`
|
||||
deleteCurrentStateQuery = `
|
||||
DELETE FROM current_state WHERE room_id = $1
|
||||
`
|
||||
getCurrentRoomStateQuery = `
|
||||
SELECT event.rowid, -1,
|
||||
event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type,
|
||||
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
|
||||
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
|
||||
FROM current_state cs
|
||||
JOIN event ON cs.event_rowid = event.rowid
|
||||
WHERE cs.room_id = $1
|
||||
`
|
||||
getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3`
|
||||
)
|
||||
|
||||
var massInsertCurrentStateBuilder = dbutil.NewMassInsertBuilder[*CurrentStateEntry, [1]any](addCurrentStateQuery, "($1, $%d, $%d, $%d, $%d)")
|
||||
|
||||
const currentStateMassInsertBatchSize = 1000
|
||||
|
||||
type CurrentStateEntry struct {
|
||||
EventType event.Type
|
||||
StateKey string
|
||||
EventRowID EventRowID
|
||||
Membership event.Membership
|
||||
}
|
||||
|
||||
func (cse *CurrentStateEntry) GetMassInsertValues() [4]any {
|
||||
return [4]any{cse.EventType.Type, cse.StateKey, cse.EventRowID, dbutil.StrPtr(cse.Membership)}
|
||||
}
|
||||
|
||||
type CurrentStateQuery struct {
|
||||
*dbutil.QueryHelper[*Event]
|
||||
}
|
||||
|
||||
func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error {
|
||||
return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
|
||||
}
|
||||
|
||||
func (csq *CurrentStateQuery) AddMany(ctx context.Context, roomID id.RoomID, deleteOld bool, entries []*CurrentStateEntry) error {
|
||||
var err error
|
||||
if deleteOld {
|
||||
err = csq.Exec(ctx, deleteCurrentStateQuery, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete old state: %w", err)
|
||||
}
|
||||
}
|
||||
for entryChunk := range slices.Chunk(entries, currentStateMassInsertBatchSize) {
|
||||
query, params := massInsertCurrentStateBuilder.Build([1]any{roomID}, entryChunk)
|
||||
err = csq.Exec(ctx, query, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (csq *CurrentStateQuery) Add(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error {
|
||||
return csq.Exec(ctx, addCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership))
|
||||
}
|
||||
|
||||
func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) {
|
||||
return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey)
|
||||
}
|
||||
|
||||
func (csq *CurrentStateQuery) GetAll(ctx context.Context, roomID id.RoomID) ([]*Event, error) {
|
||||
return csq.QueryMany(ctx, getCurrentRoomStateQuery, roomID)
|
||||
}
|
187
pkg/hicli/database/statestore.go
Normal file
187
pkg/hicli/database/statestore.go
Normal file
|
@ -0,0 +1,187 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
getMembershipQuery = `
|
||||
SELECT membership FROM current_state
|
||||
WHERE room_id = $1 AND event_type = 'm.room.member' AND state_key = $2
|
||||
`
|
||||
getStateEventContentQuery = `
|
||||
SELECT event.content FROM current_state cs
|
||||
LEFT JOIN event ON event.rowid = cs.event_rowid
|
||||
WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3
|
||||
`
|
||||
getRoomJoinedMembersQuery = `
|
||||
SELECT state_key FROM current_state
|
||||
WHERE room_id = $1 AND event_type = 'm.room.member' AND membership = 'join'
|
||||
`
|
||||
getRoomJoinedOrInvitedMembersQuery = `
|
||||
SELECT state_key FROM current_state
|
||||
WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite')
|
||||
`
|
||||
getHasFetchedMembersQuery = `
|
||||
SELECT has_member_list FROM room WHERE room_id = $1
|
||||
`
|
||||
isRoomEncryptedQuery = `
|
||||
SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1
|
||||
`
|
||||
getRoomEncryptionEventQuery = `
|
||||
SELECT room.encryption_event FROM room WHERE room_id = $1
|
||||
`
|
||||
findSharedRoomsQuery = `
|
||||
SELECT room_id FROM current_state
|
||||
WHERE event_type = 'm.room.member' AND state_key = $1 AND membership = 'join'
|
||||
`
|
||||
)
|
||||
|
||||
type ClientStateStore struct {
|
||||
*Database
|
||||
}
|
||||
|
||||
var _ mautrix.StateStore = (*ClientStateStore)(nil)
|
||||
var _ mautrix.StateStoreUpdater = (*ClientStateStore)(nil)
|
||||
var _ crypto.StateStore = (*ClientStateStore)(nil)
|
||||
|
||||
func (c *ClientStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return c.IsMembership(ctx, roomID, userID, event.MembershipJoin)
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return c.IsMembership(ctx, roomID, userID, event.MembershipInvite, event.MembershipJoin)
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
var membership event.Membership
|
||||
err := c.QueryRow(ctx, getMembershipQuery, roomID, userID).Scan(&membership)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
membership = event.MembershipLeave
|
||||
}
|
||||
return slices.Contains(allowedMemberships, membership)
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
|
||||
content, err := c.TryGetMember(ctx, roomID, userID)
|
||||
if content == nil {
|
||||
content = &event.MemberEventContent{Membership: event.MembershipLeave}
|
||||
}
|
||||
return content, err
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (content *event.MemberEventContent, err error) {
|
||||
err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StateMember.Type, userID).Scan(&dbutil.JSON{Data: &content})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) {
|
||||
err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) GetRoomJoinedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
|
||||
rows, err := c.Query(ctx, getRoomJoinedMembersQuery, roomID)
|
||||
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
|
||||
rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID)
|
||||
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList()
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (hasFetched bool, err error) {
|
||||
//err = c.QueryRow(ctx, getHasFetchedMembersQuery, roomID).Scan(&hasFetched)
|
||||
//if errors.Is(err, sql.ErrNoRows) {
|
||||
// err = nil
|
||||
//}
|
||||
//return
|
||||
return false, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) {
|
||||
err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (content *event.EncryptionEventContent, err error) {
|
||||
err = c.QueryRow(ctx, getRoomEncryptionEventQuery, roomID).
|
||||
Scan(&dbutil.JSON{Data: &content})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) {
|
||||
// TODO for multiuser support, this might need to filter by the local user's membership
|
||||
rows, err := c.Query(ctx, findSharedRoomsQuery, userID)
|
||||
return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList()
|
||||
}
|
||||
|
||||
// Update methods are all intentionally no-ops as the state store wants to have the full event
|
||||
|
||||
func (c *ClientStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {}
|
||||
|
||||
func (c *ClientStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
|
||||
return nil
|
||||
}
|
135
pkg/hicli/database/timeline.go
Normal file
135
pkg/hicli/database/timeline.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
clearTimelineQuery = `
|
||||
DELETE FROM timeline WHERE room_id = $1
|
||||
`
|
||||
appendTimelineQuery = `
|
||||
INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING
|
||||
RETURNING rowid, event_rowid
|
||||
`
|
||||
prependTimelineQuery = `
|
||||
INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3)
|
||||
`
|
||||
checkTimelineContainsQuery = `
|
||||
SELECT EXISTS(SELECT 1 FROM timeline WHERE room_id = $1 AND event_rowid = $2)
|
||||
`
|
||||
findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline`
|
||||
getTimelineQuery = `
|
||||
SELECT event.rowid, timeline.rowid,
|
||||
event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type,
|
||||
unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type,
|
||||
megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type
|
||||
FROM timeline
|
||||
JOIN event ON event.rowid = timeline.event_rowid
|
||||
WHERE timeline.room_id = $1 AND ($2 = 0 OR timeline.rowid < $2)
|
||||
ORDER BY timeline.rowid DESC
|
||||
LIMIT $3
|
||||
`
|
||||
)
|
||||
|
||||
type TimelineRowID int64
|
||||
|
||||
type TimelineRowTuple struct {
|
||||
Timeline TimelineRowID `json:"timeline_rowid"`
|
||||
Event EventRowID `json:"event_rowid"`
|
||||
}
|
||||
|
||||
var timelineRowTupleScanner = dbutil.ConvertRowFn[TimelineRowTuple](func(row dbutil.Scannable) (trt TimelineRowTuple, err error) {
|
||||
err = row.Scan(&trt.Timeline, &trt.Event)
|
||||
return
|
||||
})
|
||||
|
||||
func (trt TimelineRowTuple) GetMassInsertValues() [2]any {
|
||||
return [2]any{trt.Timeline, trt.Event}
|
||||
}
|
||||
|
||||
var appendTimelineQueryBuilder = dbutil.NewMassInsertBuilder[EventRowID, [1]any](appendTimelineQuery, "($1, $%d)")
|
||||
var prependTimelineQueryBuilder = dbutil.NewMassInsertBuilder[TimelineRowTuple, [1]any](prependTimelineQuery, "($1, $%d, $%d)")
|
||||
|
||||
type TimelineQuery struct {
|
||||
*dbutil.QueryHelper[*Event]
|
||||
|
||||
minRowID TimelineRowID
|
||||
minRowIDFound bool
|
||||
prependLock sync.Mutex
|
||||
}
|
||||
|
||||
// Clear clears the timeline of a given room.
|
||||
func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error {
|
||||
return tq.Exec(ctx, clearTimelineQuery, roomID)
|
||||
}
|
||||
|
||||
func (tq *TimelineQuery) reserveRowIDs(ctx context.Context, count int) (startFrom TimelineRowID, err error) {
|
||||
tq.prependLock.Lock()
|
||||
defer tq.prependLock.Unlock()
|
||||
if !tq.minRowIDFound {
|
||||
err = tq.GetDB().QueryRow(ctx, findMinRowIDQuery).Scan(&tq.minRowID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return
|
||||
}
|
||||
if tq.minRowID >= 0 {
|
||||
// No negative row IDs exist, start at -2
|
||||
tq.minRowID = -2
|
||||
} else {
|
||||
// We fetched the lowest row ID, but we want the next available one, so decrement one
|
||||
tq.minRowID--
|
||||
}
|
||||
tq.minRowIDFound = true
|
||||
}
|
||||
startFrom = tq.minRowID
|
||||
tq.minRowID -= TimelineRowID(count)
|
||||
return
|
||||
}
|
||||
|
||||
// Prepend adds the given event row IDs to the beginning of the timeline.
|
||||
// The events must be sorted in reverse chronological order (newest event first).
|
||||
func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) (prependEntries []TimelineRowTuple, err error) {
|
||||
var startFrom TimelineRowID
|
||||
startFrom, err = tq.reserveRowIDs(ctx, len(rowIDs))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
prependEntries = make([]TimelineRowTuple, len(rowIDs))
|
||||
for i, rowID := range rowIDs {
|
||||
prependEntries[i] = TimelineRowTuple{
|
||||
Timeline: startFrom - TimelineRowID(i),
|
||||
Event: rowID,
|
||||
}
|
||||
}
|
||||
query, params := prependTimelineQueryBuilder.Build([1]any{roomID}, prependEntries)
|
||||
err = tq.Exec(ctx, query, params...)
|
||||
return
|
||||
}
|
||||
|
||||
// Append adds the given event row IDs to the end of the timeline.
|
||||
func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) ([]TimelineRowTuple, error) {
|
||||
query, params := appendTimelineQueryBuilder.Build([1]any{roomID}, rowIDs)
|
||||
return timelineRowTupleScanner.NewRowIter(tq.GetDB().Query(ctx, query, params...)).AsList()
|
||||
}
|
||||
|
||||
func (tq *TimelineQuery) Get(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) {
|
||||
return tq.QueryMany(ctx, getTimelineQuery, roomID, before, limit)
|
||||
}
|
||||
|
||||
func (tq *TimelineQuery) Has(ctx context.Context, roomID id.RoomID, eventRowID EventRowID) (exists bool, err error) {
|
||||
err = tq.GetDB().QueryRow(ctx, checkTimelineContainsQuery, roomID, eventRowID).Scan(&exists)
|
||||
return
|
||||
}
|
255
pkg/hicli/database/upgrades/00-latest-revision.sql
Normal file
255
pkg/hicli/database/upgrades/00-latest-revision.sql
Normal file
|
@ -0,0 +1,255 @@
|
|||
-- v0 -> v3 (compatible with v1+): Latest revision
|
||||
CREATE TABLE account (
|
||||
user_id TEXT NOT NULL PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
access_token TEXT NOT NULL,
|
||||
homeserver_url TEXT NOT NULL,
|
||||
|
||||
next_batch TEXT NOT NULL
|
||||
) STRICT;
|
||||
|
||||
CREATE TABLE room (
|
||||
room_id TEXT NOT NULL PRIMARY KEY,
|
||||
creation_content TEXT,
|
||||
|
||||
name TEXT,
|
||||
name_quality INTEGER NOT NULL DEFAULT 0,
|
||||
avatar TEXT,
|
||||
explicit_avatar INTEGER NOT NULL DEFAULT 0,
|
||||
topic TEXT,
|
||||
canonical_alias TEXT,
|
||||
lazy_load_summary TEXT,
|
||||
|
||||
encryption_event TEXT,
|
||||
has_member_list INTEGER NOT NULL DEFAULT false,
|
||||
|
||||
preview_event_rowid INTEGER,
|
||||
sorting_timestamp INTEGER,
|
||||
unread_highlights INTEGER NOT NULL DEFAULT 0,
|
||||
unread_notifications INTEGER NOT NULL DEFAULT 0,
|
||||
unread_messages INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
prev_batch TEXT,
|
||||
|
||||
CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL
|
||||
) STRICT;
|
||||
CREATE INDEX room_type_idx ON room (creation_content ->> 'type');
|
||||
CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC);
|
||||
-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_notifications > 0);
|
||||
-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_messages > 0);
|
||||
|
||||
CREATE TABLE account_data (
|
||||
user_id TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY (user_id, type)
|
||||
) STRICT;
|
||||
|
||||
CREATE TABLE room_account_data (
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY (user_id, room_id, type),
|
||||
CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
|
||||
) STRICT;
|
||||
CREATE INDEX room_account_data_room_id_idx ON room_account_data (room_id);
|
||||
|
||||
CREATE TABLE event (
|
||||
rowid INTEGER PRIMARY KEY,
|
||||
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
state_key TEXT,
|
||||
timestamp INTEGER NOT NULL,
|
||||
|
||||
content TEXT NOT NULL,
|
||||
decrypted TEXT,
|
||||
decrypted_type TEXT,
|
||||
unsigned TEXT NOT NULL,
|
||||
local_content TEXT,
|
||||
|
||||
transaction_id TEXT,
|
||||
|
||||
redacted_by TEXT,
|
||||
relates_to TEXT,
|
||||
relation_type TEXT,
|
||||
|
||||
megolm_session_id TEXT,
|
||||
decryption_error TEXT,
|
||||
send_error TEXT,
|
||||
|
||||
reactions TEXT,
|
||||
last_edit_rowid INTEGER,
|
||||
unread_type INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
CONSTRAINT event_id_unique_key UNIQUE (event_id),
|
||||
CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id),
|
||||
CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
|
||||
) STRICT;
|
||||
CREATE INDEX event_room_id_idx ON event (room_id);
|
||||
CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by);
|
||||
CREATE INDEX event_relates_to_idx ON event (room_id, relates_to);
|
||||
CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id);
|
||||
|
||||
CREATE TRIGGER event_update_redacted_by
|
||||
AFTER INSERT
|
||||
ON event
|
||||
WHEN NEW.type = 'm.room.redaction'
|
||||
BEGIN
|
||||
UPDATE event SET redacted_by = NEW.event_id WHERE room_id = NEW.room_id AND event_id = NEW.content ->> 'redacts';
|
||||
END;
|
||||
|
||||
CREATE TRIGGER event_update_last_edit_when_redacted
|
||||
AFTER UPDATE
|
||||
ON event
|
||||
WHEN OLD.redacted_by IS NULL
|
||||
AND NEW.redacted_by IS NOT NULL
|
||||
AND NEW.relation_type = 'm.replace'
|
||||
AND NEW.state_key IS NULL
|
||||
BEGIN
|
||||
UPDATE event
|
||||
SET last_edit_rowid = COALESCE(
|
||||
(SELECT rowid
|
||||
FROM event edit
|
||||
WHERE edit.room_id = event.room_id
|
||||
AND edit.relates_to = event.event_id
|
||||
AND edit.relation_type = 'm.replace'
|
||||
AND edit.type = event.type
|
||||
AND edit.sender = event.sender
|
||||
AND edit.redacted_by IS NULL
|
||||
AND edit.state_key IS NULL
|
||||
ORDER BY edit.timestamp DESC
|
||||
LIMIT 1),
|
||||
0)
|
||||
WHERE event_id = NEW.relates_to
|
||||
AND last_edit_rowid = NEW.rowid
|
||||
AND state_key IS NULL
|
||||
AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation'));
|
||||
END;
|
||||
|
||||
CREATE TRIGGER event_insert_update_last_edit
|
||||
AFTER INSERT
|
||||
ON event
|
||||
WHEN NEW.relation_type = 'm.replace'
|
||||
AND NEW.redacted_by IS NULL
|
||||
AND NEW.state_key IS NULL
|
||||
BEGIN
|
||||
UPDATE event
|
||||
SET last_edit_rowid = NEW.rowid
|
||||
WHERE event_id = NEW.relates_to
|
||||
AND type = NEW.type
|
||||
AND sender = NEW.sender
|
||||
AND state_key IS NULL
|
||||
AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation'))
|
||||
AND NEW.timestamp >
|
||||
COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER event_insert_fill_reactions
|
||||
AFTER INSERT
|
||||
ON event
|
||||
WHEN NEW.type = 'm.reaction'
|
||||
AND NEW.relation_type = 'm.annotation'
|
||||
AND NEW.redacted_by IS NULL
|
||||
AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text'
|
||||
BEGIN
|
||||
UPDATE event
|
||||
SET reactions=json_set(
|
||||
reactions,
|
||||
'$.' || json_quote(NEW.content ->> '$."m.relates_to".key'),
|
||||
coalesce(
|
||||
reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')),
|
||||
0
|
||||
) + 1)
|
||||
WHERE event_id = NEW.relates_to
|
||||
AND reactions IS NOT NULL;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER event_redact_fill_reactions
|
||||
AFTER UPDATE
|
||||
ON event
|
||||
WHEN NEW.type = 'm.reaction'
|
||||
AND NEW.relation_type = 'm.annotation'
|
||||
AND NEW.redacted_by IS NOT NULL
|
||||
AND OLD.redacted_by IS NULL
|
||||
AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text'
|
||||
BEGIN
|
||||
UPDATE event
|
||||
SET reactions=json_set(
|
||||
reactions,
|
||||
'$.' || json_quote(NEW.content ->> '$."m.relates_to".key'),
|
||||
coalesce(
|
||||
reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')),
|
||||
0
|
||||
) - 1)
|
||||
WHERE event_id = NEW.relates_to
|
||||
AND reactions IS NOT NULL;
|
||||
END;
|
||||
|
||||
CREATE TABLE cached_media (
|
||||
mxc TEXT NOT NULL PRIMARY KEY,
|
||||
event_rowid INTEGER,
|
||||
enc_file TEXT,
|
||||
file_name TEXT,
|
||||
mime_type TEXT,
|
||||
size INTEGER,
|
||||
hash BLOB,
|
||||
error TEXT,
|
||||
|
||||
CONSTRAINT cached_media_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE SET NULL
|
||||
) STRICT;
|
||||
|
||||
CREATE TABLE session_request (
|
||||
room_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
min_index INTEGER NOT NULL,
|
||||
backup_checked INTEGER NOT NULL DEFAULT false,
|
||||
request_sent INTEGER NOT NULL DEFAULT false,
|
||||
|
||||
PRIMARY KEY (session_id),
|
||||
CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
|
||||
) STRICT;
|
||||
CREATE INDEX session_request_room_idx ON session_request (room_id);
|
||||
|
||||
CREATE TABLE timeline (
|
||||
rowid INTEGER PRIMARY KEY,
|
||||
room_id TEXT NOT NULL,
|
||||
event_rowid INTEGER NOT NULL,
|
||||
|
||||
CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE,
|
||||
CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE,
|
||||
CONSTRAINT timeline_event_unique_key UNIQUE (event_rowid)
|
||||
) STRICT;
|
||||
CREATE INDEX timeline_room_id_idx ON timeline (room_id);
|
||||
|
||||
CREATE TABLE current_state (
|
||||
room_id TEXT NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
state_key TEXT NOT NULL,
|
||||
event_rowid INTEGER NOT NULL,
|
||||
|
||||
membership TEXT,
|
||||
|
||||
PRIMARY KEY (room_id, event_type, state_key),
|
||||
CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE,
|
||||
CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid)
|
||||
) STRICT, WITHOUT ROWID;
|
||||
|
||||
CREATE TABLE receipt (
|
||||
room_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
receipt_type TEXT NOT NULL,
|
||||
thread_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
timestamp INTEGER NOT NULL,
|
||||
|
||||
PRIMARY KEY (room_id, user_id, receipt_type, thread_id),
|
||||
CONSTRAINT receipt_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE
|
||||
-- note: there's no foreign key on event ID because receipts could point at events that are too far in history.
|
||||
) STRICT;
|
2
pkg/hicli/database/upgrades/02-explicit-avatar-flag.sql
Normal file
2
pkg/hicli/database/upgrades/02-explicit-avatar-flag.sql
Normal file
|
@ -0,0 +1,2 @@
|
|||
-- v2 (compatible with v1+): Add explicit avatar flag to rooms
|
||||
ALTER TABLE room ADD COLUMN explicit_avatar INTEGER NOT NULL DEFAULT 0;
|
6
pkg/hicli/database/upgrades/03-more-event-fields.sql
Normal file
6
pkg/hicli/database/upgrades/03-more-event-fields.sql
Normal file
|
@ -0,0 +1,6 @@
|
|||
-- v3 (compatible with v1+): Add more fields to events
|
||||
ALTER TABLE event ADD COLUMN local_content TEXT;
|
||||
ALTER TABLE event ADD COLUMN unread_type INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE room ADD COLUMN unread_highlights INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE room ADD COLUMN unread_notifications INTEGER NOT NULL DEFAULT 0;
|
||||
ALTER TABLE room ADD COLUMN unread_messages INTEGER NOT NULL DEFAULT 0;
|
22
pkg/hicli/database/upgrades/upgrades.go
Normal file
22
pkg/hicli/database/upgrades/upgrades.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package upgrades
|
||||
|
||||
import (
|
||||
"embed"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
var Table dbutil.UpgradeTable
|
||||
|
||||
//go:embed *.sql
|
||||
var upgrades embed.FS
|
||||
|
||||
func init() {
|
||||
Table.RegisterFS(upgrades)
|
||||
}
|
209
pkg/hicli/decryptionqueue.go
Normal file
209
pkg/hicli/decryptionqueue.go
Normal file
|
@ -0,0 +1,209 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
)
|
||||
|
||||
func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
|
||||
data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted)
|
||||
}
|
||||
|
||||
func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session")
|
||||
}
|
||||
events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption")
|
||||
return
|
||||
} else if len(events) == 0 {
|
||||
log.Trace().Msg("No events to retry decryption for")
|
||||
return
|
||||
}
|
||||
decrypted := events[:0]
|
||||
for _, evt := range events {
|
||||
if evt.Decrypted != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var mautrixEvt *event.Event
|
||||
mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix())
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session")
|
||||
} else {
|
||||
decrypted = append(decrypted, evt)
|
||||
h.postDecryptProcess(ctx, nil, evt, mautrixEvt)
|
||||
}
|
||||
}
|
||||
if len(decrypted) > 0 {
|
||||
var newPreview database.EventRowID
|
||||
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
for _, evt := range decrypted {
|
||||
err = h.DB.Event.UpdateDecrypted(ctx, evt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err)
|
||||
}
|
||||
if evt.CanUseForPreview() {
|
||||
var previewChanged bool
|
||||
previewChanged, err = h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, err)
|
||||
} else if previewChanged {
|
||||
newPreview = evt.RowID
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to save decrypted events")
|
||||
} else {
|
||||
h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewEventRowID: newPreview, RoomID: roomID})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) WakeupRequestQueue() {
|
||||
select {
|
||||
case h.requestQueueWakeup <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) RunRequestQueue(ctx context.Context) {
|
||||
log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
log.Info().Msg("Starting key request queue")
|
||||
defer func() {
|
||||
log.Info().Msg("Stopping key request queue")
|
||||
}()
|
||||
for {
|
||||
err := h.FetchKeysForOutdatedUsers(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to fetch outdated device lists for tracked users")
|
||||
}
|
||||
madeRequests, err := h.RequestQueuedSessions(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to handle session request queue")
|
||||
} else if madeRequests {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-h.requestQueueWakeup:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) {
|
||||
defer doneFunc()
|
||||
log := zerolog.Ctx(ctx)
|
||||
if !req.BackupChecked {
|
||||
sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID)
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Stringer("session_id", req.SessionID).
|
||||
Msg("Failed to fetch session from key backup")
|
||||
|
||||
// TODO should this have retries instead of just storing it's checked?
|
||||
req.BackupChecked = true
|
||||
err = h.DB.SessionRequest.Put(ctx, req)
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup")
|
||||
}
|
||||
} else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex {
|
||||
req.BackupChecked = true
|
||||
err = h.DB.SessionRequest.Put(ctx, req)
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup")
|
||||
}
|
||||
} else {
|
||||
log.Debug().Stringer("session_id", req.SessionID).
|
||||
Msg("Found session with sufficiently low first known index, removing from queue")
|
||||
err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex())
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{
|
||||
h.Account.UserID: {"*"},
|
||||
req.Sender: {"*"},
|
||||
})
|
||||
//var err error
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Stringer("session_id", req.SessionID).
|
||||
Msg("Failed to send key request")
|
||||
} else {
|
||||
log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request")
|
||||
req.RequestSent = true
|
||||
err = h.DB.SessionRequest.Put(ctx, req)
|
||||
if err != nil {
|
||||
log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const MaxParallelRequests = 5
|
||||
|
||||
func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) {
|
||||
sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get next events to decrypt: %w", err)
|
||||
} else if len(sessions) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(sessions))
|
||||
for _, req := range sessions {
|
||||
go h.requestQueuedSession(ctx, req, wg.Done)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return true, err
|
||||
}
|
||||
|
||||
func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error {
|
||||
outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(outdatedUsers) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO backoff for users that fail to be fetched?
|
||||
return nil
|
||||
}
|
60
pkg/hicli/events.go
Normal file
60
pkg/hicli/events.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
)
|
||||
|
||||
type SyncRoom struct {
|
||||
Meta *database.Room `json:"meta"`
|
||||
Timeline []database.TimelineRowTuple `json:"timeline"`
|
||||
State map[event.Type]map[string]database.EventRowID `json:"state"`
|
||||
Events []*database.Event `json:"events"`
|
||||
Reset bool `json:"reset"`
|
||||
Notifications []SyncNotification `json:"notifications"`
|
||||
}
|
||||
|
||||
type SyncNotification struct {
|
||||
RowID database.EventRowID `json:"event_rowid"`
|
||||
Sound bool `json:"sound"`
|
||||
}
|
||||
|
||||
type SyncComplete struct {
|
||||
Rooms map[id.RoomID]*SyncRoom `json:"rooms"`
|
||||
}
|
||||
|
||||
func (c *SyncComplete) IsEmpty() bool {
|
||||
return len(c.Rooms) == 0
|
||||
}
|
||||
|
||||
type EventsDecrypted struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
PreviewEventRowID database.EventRowID `json:"preview_event_rowid,omitempty"`
|
||||
Events []*database.Event `json:"events"`
|
||||
}
|
||||
|
||||
type Typing struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
event.TypingEventContent
|
||||
}
|
||||
|
||||
type SendComplete struct {
|
||||
Event *database.Event `json:"event"`
|
||||
Error error `json:"error"`
|
||||
}
|
||||
|
||||
type ClientState struct {
|
||||
IsLoggedIn bool `json:"is_logged_in"`
|
||||
IsVerified bool `json:"is_verified"`
|
||||
UserID id.UserID `json:"user_id,omitempty"`
|
||||
DeviceID id.DeviceID `json:"device_id,omitempty"`
|
||||
HomeserverURL string `json:"homeserver_url,omitempty"`
|
||||
}
|
251
pkg/hicli/hicli.go
Normal file
251
pkg/hicli/hicli.go
Normal file
|
@ -0,0 +1,251 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix.
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/crypto/backup"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/pushrules"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
)
|
||||
|
||||
type HiClient struct {
|
||||
DB *database.Database
|
||||
Account *database.Account
|
||||
Client *mautrix.Client
|
||||
Crypto *crypto.OlmMachine
|
||||
CryptoStore *crypto.SQLCryptoStore
|
||||
ClientStore *database.ClientStateStore
|
||||
Log zerolog.Logger
|
||||
|
||||
Verified bool
|
||||
|
||||
KeyBackupVersion id.KeyBackupVersion
|
||||
KeyBackupKey *backup.MegolmBackupKey
|
||||
|
||||
PushRules atomic.Pointer[pushrules.PushRuleset]
|
||||
|
||||
EventHandler func(evt any)
|
||||
|
||||
firstSyncReceived bool
|
||||
syncingID int
|
||||
syncLock sync.Mutex
|
||||
stopSync atomic.Pointer[context.CancelFunc]
|
||||
encryptLock sync.Mutex
|
||||
|
||||
requestQueueWakeup chan struct{}
|
||||
|
||||
jsonRequestsLock sync.Mutex
|
||||
jsonRequests map[int64]context.CancelCauseFunc
|
||||
|
||||
paginationInterrupterLock sync.Mutex
|
||||
paginationInterrupter map[id.RoomID]context.CancelCauseFunc
|
||||
}
|
||||
|
||||
var ErrTimelineReset = errors.New("got limited timeline sync response")
|
||||
|
||||
func New(rawDB, cryptoDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandler func(any)) *HiClient {
|
||||
if cryptoDB == nil {
|
||||
cryptoDB = rawDB
|
||||
}
|
||||
if rawDB.Owner == "" {
|
||||
rawDB.Owner = "hicli"
|
||||
rawDB.IgnoreForeignTables = true
|
||||
}
|
||||
if rawDB.Log == nil {
|
||||
rawDB.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger())
|
||||
}
|
||||
db := database.New(rawDB)
|
||||
c := &HiClient{
|
||||
DB: db,
|
||||
Log: log,
|
||||
|
||||
requestQueueWakeup: make(chan struct{}, 1),
|
||||
jsonRequests: make(map[int64]context.CancelCauseFunc),
|
||||
paginationInterrupter: make(map[id.RoomID]context.CancelCauseFunc),
|
||||
|
||||
EventHandler: evtHandler,
|
||||
}
|
||||
c.ClientStore = &database.ClientStateStore{Database: db}
|
||||
c.Client = &mautrix.Client{
|
||||
UserAgent: mautrix.DefaultUserAgent,
|
||||
Client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
// This needs to be relatively high to allow initial syncs
|
||||
ResponseHeaderTimeout: 180 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
},
|
||||
Timeout: 180 * time.Second,
|
||||
},
|
||||
Syncer: (*hiSyncer)(c),
|
||||
Store: (*hiStore)(c),
|
||||
StateStore: c.ClientStore,
|
||||
Log: log.With().Str("component", "mautrix client").Logger(),
|
||||
}
|
||||
c.CryptoStore = crypto.NewSQLCryptoStore(cryptoDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey)
|
||||
cryptoLog := log.With().Str("component", "crypto").Logger()
|
||||
c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore)
|
||||
c.Crypto.SessionReceived = c.handleReceivedMegolmSession
|
||||
c.Crypto.DisableRatchetTracking = true
|
||||
c.Crypto.DisableDecryptKeyFetching = true
|
||||
c.Client.Crypto = (*hiCryptoHelper)(c)
|
||||
return c
|
||||
}
|
||||
|
||||
func (h *HiClient) IsLoggedIn() bool {
|
||||
return h.Account != nil
|
||||
}
|
||||
|
||||
func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount *database.Account) error {
|
||||
if expectedAccount != nil && userID != expectedAccount.UserID {
|
||||
panic(fmt.Errorf("invalid parameters: different user ID in expected account and user ID"))
|
||||
}
|
||||
err := h.DB.Upgrade(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade hicli db: %w", err)
|
||||
}
|
||||
err = h.CryptoStore.DB.Upgrade(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade crypto db: %w", err)
|
||||
}
|
||||
account, err := h.DB.Account.Get(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if account == nil && expectedAccount != nil {
|
||||
err = h.DB.Account.Put(ctx, expectedAccount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
account = expectedAccount
|
||||
} else if expectedAccount != nil && expectedAccount.DeviceID != account.DeviceID {
|
||||
return fmt.Errorf("device ID mismatch: expected %s, got %s", expectedAccount.DeviceID, account.DeviceID)
|
||||
}
|
||||
if account != nil {
|
||||
zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials")
|
||||
h.Account = account
|
||||
h.CryptoStore.AccountID = account.UserID.String()
|
||||
h.CryptoStore.DeviceID = account.DeviceID
|
||||
h.Client.UserID = account.UserID
|
||||
h.Client.DeviceID = account.DeviceID
|
||||
h.Client.AccessToken = account.AccessToken
|
||||
h.Client.HomeserverURL, err = url.Parse(account.HomeserverURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.CheckServerVersions(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.Crypto.Load(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load olm machine: %w", err)
|
||||
}
|
||||
|
||||
h.Verified, err = h.checkIsCurrentDeviceVerified(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status")
|
||||
if h.Verified {
|
||||
err = h.loadPrivateKeys(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go h.Sync()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrFailedToCheckServerVersions = errors.New("failed to check server versions")
|
||||
var ErrOutdatedServer = errors.New("homeserver is outdated")
|
||||
var MinimumSpecVersion = mautrix.SpecV11
|
||||
|
||||
func (h *HiClient) CheckServerVersions(ctx context.Context) error {
|
||||
versions, err := h.Client.Versions(ctx)
|
||||
if err != nil {
|
||||
return exerrors.NewDualError(ErrFailedToCheckServerVersions, err)
|
||||
} else if !versions.Contains(MinimumSpecVersion) {
|
||||
return fmt.Errorf("%w (minimum: %s, highest supported: %s)", ErrOutdatedServer, MinimumSpecVersion, versions.GetLatest())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) IsSyncing() bool {
|
||||
return h.stopSync.Load() != nil
|
||||
}
|
||||
|
||||
func (h *HiClient) Sync() {
|
||||
h.Client.StopSync()
|
||||
if fn := h.stopSync.Load(); fn != nil {
|
||||
(*fn)()
|
||||
}
|
||||
h.syncLock.Lock()
|
||||
defer h.syncLock.Unlock()
|
||||
h.syncingID++
|
||||
syncingID := h.syncingID
|
||||
log := h.Log.With().
|
||||
Str("action", "sync").
|
||||
Int("sync_id", syncingID).
|
||||
Logger()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
h.stopSync.Store(&cancel)
|
||||
go h.RunRequestQueue(h.Log.WithContext(ctx))
|
||||
go h.LoadPushRules(h.Log.WithContext(ctx))
|
||||
ctx = log.WithContext(ctx)
|
||||
log.Info().Msg("Starting syncing")
|
||||
err := h.Client.SyncWithContext(ctx)
|
||||
if err != nil && ctx.Err() == nil {
|
||||
log.Err(err).Msg("Fatal error in syncer")
|
||||
} else {
|
||||
log.Info().Msg("Syncing stopped")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) LoadPushRules(ctx context.Context) {
|
||||
rules, err := h.Client.GetPushRules(ctx)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to load push rules")
|
||||
return
|
||||
}
|
||||
h.PushRules.Store(rules)
|
||||
zerolog.Ctx(ctx).Debug().Msg("Updated push rules from fetch")
|
||||
}
|
||||
|
||||
func (h *HiClient) Stop() {
|
||||
h.Client.StopSync()
|
||||
if fn := h.stopSync.Swap(nil); fn != nil {
|
||||
(*fn)()
|
||||
}
|
||||
h.syncLock.Lock()
|
||||
//lint:ignore SA2001 just acquire the lock to make sure Sync is done
|
||||
h.syncLock.Unlock()
|
||||
err := h.DB.Close()
|
||||
if err != nil {
|
||||
h.Log.Err(err).Msg("Failed to close database cleanly")
|
||||
}
|
||||
}
|
110
pkg/hicli/hitest/hitest.go
Normal file
110
pkg/hicli/hitest/hitest.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/chzyer/readline"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/dbutil"
|
||||
_ "go.mau.fi/util/dbutil/litestream"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"go.mau.fi/util/exzerolog"
|
||||
"go.mau.fi/zeroconfig"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli"
|
||||
)
|
||||
|
||||
var writerTypeReadline zeroconfig.WriterType = "hitest_readline"
|
||||
|
||||
func main() {
|
||||
hicli.InitialDeviceDisplayName = "mautrix hitest"
|
||||
rl := exerrors.Must(readline.New("> "))
|
||||
defer func() {
|
||||
_ = rl.Close()
|
||||
}()
|
||||
zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) {
|
||||
return rl.Stdout(), nil
|
||||
})
|
||||
debug := zerolog.DebugLevel
|
||||
log := exerrors.Must((&zeroconfig.Config{
|
||||
MinLevel: &debug,
|
||||
Writers: []zeroconfig.WriterConfig{{
|
||||
Type: writerTypeReadline,
|
||||
Format: zeroconfig.LogFormatPrettyColored,
|
||||
}},
|
||||
}).Compile())
|
||||
exzerolog.SetupDefaults(log)
|
||||
|
||||
rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal"))
|
||||
ctx := log.WithContext(context.Background())
|
||||
cli := hicli.New(rawDB, nil, *log, []byte("meow"), func(a any) {
|
||||
_, _ = fmt.Fprintf(rl, "Received event of type %T\n", a)
|
||||
switch evt := a.(type) {
|
||||
case *hicli.SyncComplete:
|
||||
for _, room := range evt.Rooms {
|
||||
name := "name unset"
|
||||
if room.Meta.Name != nil {
|
||||
name = *room.Meta.Name
|
||||
}
|
||||
_, _ = fmt.Fprintf(rl, "Room %s (%s) in sync:\n", name, room.Meta.ID)
|
||||
_, _ = fmt.Fprintf(rl, " Preview: %d, sort: %v\n", room.Meta.PreviewEventRowID, room.Meta.SortingTimestamp)
|
||||
_, _ = fmt.Fprintf(rl, " Timeline: +%d %v, reset: %t\n", len(room.Timeline), room.Timeline, room.Reset)
|
||||
}
|
||||
case *hicli.EventsDecrypted:
|
||||
for _, decrypted := range evt.Events {
|
||||
_, _ = fmt.Fprintf(rl, "Delayed decryption of %s completed: %s / %s\n", decrypted.ID, decrypted.DecryptedType, decrypted.Decrypted)
|
||||
}
|
||||
if evt.PreviewEventRowID != 0 {
|
||||
_, _ = fmt.Fprintf(rl, "Room preview updated: %+v\n", evt.PreviewEventRowID)
|
||||
}
|
||||
case *hicli.Typing:
|
||||
_, _ = fmt.Fprintf(rl, "Typing list in %s: %+v\n", evt.RoomID, evt.UserIDs)
|
||||
}
|
||||
})
|
||||
userID, _ := cli.DB.Account.GetFirstUserID(ctx)
|
||||
exerrors.PanicIfNotNil(cli.Start(ctx, userID, nil))
|
||||
if !cli.IsLoggedIn() {
|
||||
rl.SetPrompt("User ID: ")
|
||||
userID := id.UserID(exerrors.Must(rl.Readline()))
|
||||
_, serverName := exerrors.Must2(userID.Parse())
|
||||
discovery := exerrors.Must(mautrix.DiscoverClientAPI(ctx, serverName))
|
||||
password := exerrors.Must(rl.ReadPassword("Password: "))
|
||||
recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: "))
|
||||
exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode)))
|
||||
}
|
||||
rl.SetPrompt("> ")
|
||||
|
||||
for {
|
||||
line, err := rl.Readline()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) == 0 {
|
||||
continue
|
||||
}
|
||||
switch strings.ToLower(fields[0]) {
|
||||
case "send":
|
||||
resp, err := cli.Send(ctx, id.RoomID(fields[1]), event.EventMessage, &event.MessageEventContent{
|
||||
Body: strings.Join(fields[2:], " "),
|
||||
MsgType: event.MsgText,
|
||||
})
|
||||
_, _ = fmt.Fprintln(rl, err)
|
||||
_, _ = fmt.Fprintf(rl, "%+v\n", resp)
|
||||
}
|
||||
}
|
||||
cli.Stop()
|
||||
}
|
493
pkg/hicli/html.go
Normal file
493
pkg/hicli/html.go
Normal file
|
@ -0,0 +1,493 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
"golang.org/x/net/html/atom"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"mvdan.cc/xurls/v2"
|
||||
)
|
||||
|
||||
func tagIsAllowed(tag atom.Atom) bool {
|
||||
switch tag {
|
||||
case atom.Del, atom.H1, atom.H2, atom.H3, atom.H4, atom.H5, atom.H6, atom.Blockquote, atom.P,
|
||||
atom.A, atom.Ul, atom.Ol, atom.Sup, atom.Sub, atom.Li, atom.B, atom.I, atom.U, atom.Strong,
|
||||
atom.Em, atom.S, atom.Code, atom.Hr, atom.Br, atom.Div, atom.Table, atom.Thead, atom.Tbody,
|
||||
atom.Tr, atom.Th, atom.Td, atom.Caption, atom.Pre, atom.Span, atom.Font, atom.Img,
|
||||
atom.Details, atom.Summary:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isSelfClosing(tag atom.Atom) bool {
|
||||
switch tag {
|
||||
case atom.Img, atom.Br, atom.Hr:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
var languageRegex = regexp.MustCompile(`^language-[a-zA-Z0-9-]+$`)
|
||||
var allowedColorRegex = regexp.MustCompile(`^#[0-9a-fA-F]{6}$`)
|
||||
|
||||
// This is approximately a mirror of web/src/util/mediasize.ts in gomuks
|
||||
func calculateMediaSize(widthInt, heightInt int) (width, height float64, ok bool) {
|
||||
if widthInt <= 0 || heightInt <= 0 {
|
||||
return
|
||||
}
|
||||
width = float64(widthInt)
|
||||
height = float64(heightInt)
|
||||
const imageContainerWidth float64 = 320
|
||||
const imageContainerHeight float64 = 240
|
||||
const imageContainerAspectRatio = imageContainerWidth / imageContainerHeight
|
||||
if width > imageContainerWidth || height > imageContainerHeight {
|
||||
aspectRatio := width / height
|
||||
if aspectRatio > imageContainerAspectRatio {
|
||||
width = imageContainerWidth
|
||||
height = imageContainerWidth / aspectRatio
|
||||
} else if aspectRatio < imageContainerAspectRatio {
|
||||
width = imageContainerHeight * aspectRatio
|
||||
height = imageContainerHeight
|
||||
} else {
|
||||
width = imageContainerWidth
|
||||
height = imageContainerHeight
|
||||
}
|
||||
}
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func parseImgAttributes(attrs []html.Attribute) (src, alt, title string, isCustomEmoji bool, width, height int) {
|
||||
for _, attr := range attrs {
|
||||
switch attr.Key {
|
||||
case "src":
|
||||
src = attr.Val
|
||||
case "alt":
|
||||
alt = attr.Val
|
||||
case "title":
|
||||
title = attr.Val
|
||||
case "data-mx-emoticon":
|
||||
isCustomEmoji = true
|
||||
case "width":
|
||||
width, _ = strconv.Atoi(attr.Val)
|
||||
case "height":
|
||||
height, _ = strconv.Atoi(attr.Val)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseSpanAttributes(attrs []html.Attribute) (bgColor, textColor, spoiler, maths string, isSpoiler bool) {
|
||||
for _, attr := range attrs {
|
||||
switch attr.Key {
|
||||
case "data-mx-bg-color":
|
||||
if allowedColorRegex.MatchString(attr.Val) {
|
||||
bgColor = attr.Val
|
||||
}
|
||||
case "data-mx-color", "color":
|
||||
if allowedColorRegex.MatchString(attr.Val) {
|
||||
textColor = attr.Val
|
||||
}
|
||||
case "data-mx-spoiler":
|
||||
spoiler = attr.Val
|
||||
isSpoiler = true
|
||||
case "data-mx-maths":
|
||||
maths = attr.Val
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseAAttributes(attrs []html.Attribute) (href string) {
|
||||
for _, attr := range attrs {
|
||||
switch attr.Key {
|
||||
case "href":
|
||||
href = strings.TrimSpace(attr.Val)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func attributeIsAllowed(tag atom.Atom, attr html.Attribute) bool {
|
||||
switch tag {
|
||||
case atom.Ol:
|
||||
switch attr.Key {
|
||||
case "start":
|
||||
_, err := strconv.Atoi(attr.Val)
|
||||
return err == nil
|
||||
}
|
||||
case atom.Code:
|
||||
switch attr.Key {
|
||||
case "class":
|
||||
return languageRegex.MatchString(attr.Val)
|
||||
}
|
||||
case atom.Div:
|
||||
switch attr.Key {
|
||||
case "data-mx-maths":
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Funny user IDs will just need to be linkified by the sender, no auto-linkification for them.
|
||||
var plainUserOrAliasMentionRegex = regexp.MustCompile(`[@#][a-zA-Z0-9._=/+-]{0,254}:[a-zA-Z0-9.-]+(?:\d{1,5})?`)
|
||||
|
||||
func getNextItem(items [][]int, minIndex int) (index, start, end int, ok bool) {
|
||||
for i, item := range items {
|
||||
if item[0] >= minIndex {
|
||||
return i, item[0], item[1], true
|
||||
}
|
||||
}
|
||||
return -1, -1, -1, false
|
||||
}
|
||||
|
||||
func writeMention(w *strings.Builder, mention []byte) {
|
||||
uri := &id.MatrixURI{
|
||||
Sigil1: rune(mention[0]),
|
||||
MXID1: string(mention[1:]),
|
||||
}
|
||||
w.WriteString(`<a`)
|
||||
writeAttribute(w, "href", uri.String())
|
||||
writeAttribute(w, "class", matrixURIClassName(uri)+" hicli-matrix-uri-plaintext")
|
||||
w.WriteByte('>')
|
||||
writeEscapedBytes(w, mention)
|
||||
w.WriteString("</a>")
|
||||
}
|
||||
|
||||
func writeURL(w *strings.Builder, addr []byte) {
|
||||
parsedURL, err := url.Parse(string(addr))
|
||||
if err != nil {
|
||||
writeEscapedBytes(w, addr)
|
||||
return
|
||||
}
|
||||
if parsedURL.Scheme == "" {
|
||||
parsedURL.Scheme = "https"
|
||||
}
|
||||
w.WriteString(`<a target="_blank" rel="noreferrer noopener"`)
|
||||
writeAttribute(w, "href", parsedURL.String())
|
||||
w.WriteByte('>')
|
||||
writeEscapedBytes(w, addr)
|
||||
w.WriteString("</a>")
|
||||
}
|
||||
|
||||
func linkifyAndWriteBytes(w *strings.Builder, s []byte) {
|
||||
mentions := plainUserOrAliasMentionRegex.FindAllIndex(s, -1)
|
||||
urls := xurls.Relaxed().FindAllIndex(s, -1)
|
||||
minIndex := 0
|
||||
for {
|
||||
mentionIdx, nextMentionStart, nextMentionEnd, hasMention := getNextItem(mentions, minIndex)
|
||||
urlIdx, nextURLStart, nextURLEnd, hasURL := getNextItem(urls, minIndex)
|
||||
if hasMention && (!hasURL || nextMentionStart <= nextURLStart) {
|
||||
writeEscapedBytes(w, s[minIndex:nextMentionStart])
|
||||
writeMention(w, s[nextMentionStart:nextMentionEnd])
|
||||
minIndex = nextMentionEnd
|
||||
mentions = mentions[mentionIdx:]
|
||||
} else if hasURL && (!hasMention || nextURLStart < nextMentionStart) {
|
||||
writeEscapedBytes(w, s[minIndex:nextURLStart])
|
||||
writeURL(w, s[nextURLStart:nextURLEnd])
|
||||
minIndex = nextURLEnd
|
||||
urls = urls[urlIdx:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
writeEscapedBytes(w, s[minIndex:])
|
||||
}
|
||||
|
||||
const escapedChars = "&'<>\"\r"
|
||||
|
||||
func writeEscapedBytes(w *strings.Builder, s []byte) {
|
||||
i := bytes.IndexAny(s, escapedChars)
|
||||
for i != -1 {
|
||||
w.Write(s[:i])
|
||||
var esc string
|
||||
switch s[i] {
|
||||
case '&':
|
||||
esc = "&"
|
||||
case '\'':
|
||||
// "'" is shorter than "'" and apos was not in HTML until HTML5.
|
||||
esc = "'"
|
||||
case '<':
|
||||
esc = "<"
|
||||
case '>':
|
||||
esc = ">"
|
||||
case '"':
|
||||
// """ is shorter than """.
|
||||
esc = """
|
||||
case '\r':
|
||||
esc = " "
|
||||
default:
|
||||
panic("unrecognized escape character")
|
||||
}
|
||||
s = s[i+1:]
|
||||
w.WriteString(esc)
|
||||
i = bytes.IndexAny(s, escapedChars)
|
||||
}
|
||||
w.Write(s)
|
||||
}
|
||||
|
||||
func writeEscapedString(w *strings.Builder, s string) {
|
||||
i := strings.IndexAny(s, escapedChars)
|
||||
for i != -1 {
|
||||
w.WriteString(s[:i])
|
||||
var esc string
|
||||
switch s[i] {
|
||||
case '&':
|
||||
esc = "&"
|
||||
case '\'':
|
||||
// "'" is shorter than "'" and apos was not in HTML until HTML5.
|
||||
esc = "'"
|
||||
case '<':
|
||||
esc = "<"
|
||||
case '>':
|
||||
esc = ">"
|
||||
case '"':
|
||||
// """ is shorter than """.
|
||||
esc = """
|
||||
case '\r':
|
||||
esc = " "
|
||||
default:
|
||||
panic("unrecognized escape character")
|
||||
}
|
||||
s = s[i+1:]
|
||||
w.WriteString(esc)
|
||||
i = strings.IndexAny(s, escapedChars)
|
||||
}
|
||||
w.WriteString(s)
|
||||
}
|
||||
|
||||
func writeAttribute(w *strings.Builder, key, value string) {
|
||||
w.WriteByte(' ')
|
||||
w.WriteString(key)
|
||||
w.WriteString(`="`)
|
||||
writeEscapedString(w, value)
|
||||
w.WriteByte('"')
|
||||
}
|
||||
|
||||
func matrixURIClassName(uri *id.MatrixURI) string {
|
||||
switch uri.Sigil1 {
|
||||
case '@':
|
||||
return "hicli-matrix-uri hicli-matrix-uri-user"
|
||||
case '#':
|
||||
return "hicli-matrix-uri hicli-matrix-uri-room-alias"
|
||||
case '!':
|
||||
if uri.Sigil2 == '$' {
|
||||
return "hicli-matrix-uri hicli-matrix-uri-event-id"
|
||||
}
|
||||
return "hicli-matrix-uri hicli-matrix-uri-room-id"
|
||||
default:
|
||||
return "hicli-matrix-uri hicli-matrix-uri-unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func writeA(w *strings.Builder, attr []html.Attribute) {
|
||||
w.WriteString("<a")
|
||||
href := parseAAttributes(attr)
|
||||
if href == "" {
|
||||
return
|
||||
}
|
||||
parsedURL, err := url.Parse(href)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
newTab := true
|
||||
switch parsedURL.Scheme {
|
||||
case "bitcoin", "ftp", "geo", "http", "im", "irc", "ircs", "magnet", "mailto",
|
||||
"mms", "news", "nntp", "openpgp4fpr", "sip", "sftp", "sms", "smsto", "ssh",
|
||||
"tel", "urn", "webcal", "wtai", "xmpp":
|
||||
// allowed
|
||||
case "https":
|
||||
if parsedURL.Host == "matrix.to" {
|
||||
uri, err := id.ProcessMatrixToURL(parsedURL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
href = uri.String()
|
||||
newTab = false
|
||||
writeAttribute(w, "class", matrixURIClassName(uri))
|
||||
}
|
||||
case "matrix":
|
||||
uri, err := id.ProcessMatrixURI(parsedURL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
href = uri.String()
|
||||
newTab = false
|
||||
writeAttribute(w, "class", matrixURIClassName(uri))
|
||||
case "mxc":
|
||||
mxc := id.ContentURIString(href).ParseOrIgnore()
|
||||
if !mxc.IsValid() {
|
||||
return
|
||||
}
|
||||
href = fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID)
|
||||
default:
|
||||
return
|
||||
}
|
||||
writeAttribute(w, "href", href)
|
||||
if newTab {
|
||||
writeAttribute(w, "target", "_blank")
|
||||
writeAttribute(w, "rel", "noreferrer noopener")
|
||||
}
|
||||
}
|
||||
|
||||
var HTMLSanitizerImgSrcTemplate = "mxc://%s/%s"
|
||||
|
||||
func writeImg(w *strings.Builder, attr []html.Attribute) {
|
||||
src, alt, title, isCustomEmoji, width, height := parseImgAttributes(attr)
|
||||
w.WriteString("<img")
|
||||
writeAttribute(w, "alt", alt)
|
||||
if title != "" {
|
||||
writeAttribute(w, "title", title)
|
||||
}
|
||||
mxc := id.ContentURIString(src).ParseOrIgnore()
|
||||
if !mxc.IsValid() {
|
||||
return
|
||||
}
|
||||
writeAttribute(w, "src", fmt.Sprintf(HTMLSanitizerImgSrcTemplate, mxc.Homeserver, mxc.FileID))
|
||||
writeAttribute(w, "loading", "lazy")
|
||||
if isCustomEmoji {
|
||||
writeAttribute(w, "class", "hicli-custom-emoji")
|
||||
} else if cWidth, cHeight, sizeOK := calculateMediaSize(width, height); sizeOK {
|
||||
writeAttribute(w, "class", "hicli-sized-inline-img")
|
||||
writeAttribute(w, "style", fmt.Sprintf("width: %.2fpx; height: %.2fpx;", cWidth, cHeight))
|
||||
} else {
|
||||
writeAttribute(w, "class", "hicli-sizeless-inline-img")
|
||||
}
|
||||
}
|
||||
|
||||
func writeSpan(w *strings.Builder, attr []html.Attribute) {
|
||||
bgColor, textColor, spoiler, _, isSpoiler := parseSpanAttributes(attr)
|
||||
if isSpoiler && spoiler != "" {
|
||||
w.WriteString(`<span class="spoiler-reason">`)
|
||||
w.WriteString(spoiler)
|
||||
w.WriteString(" </span>")
|
||||
}
|
||||
w.WriteByte('<')
|
||||
w.WriteString("span")
|
||||
if isSpoiler {
|
||||
writeAttribute(w, "class", "hicli-spoiler")
|
||||
}
|
||||
var style string
|
||||
if bgColor != "" {
|
||||
style += fmt.Sprintf("background-color: %s;", bgColor)
|
||||
}
|
||||
if textColor != "" {
|
||||
style += fmt.Sprintf("color: %s;", textColor)
|
||||
}
|
||||
if style != "" {
|
||||
writeAttribute(w, "style", style)
|
||||
}
|
||||
}
|
||||
|
||||
type tagStack []atom.Atom
|
||||
|
||||
func (ts *tagStack) contains(tags ...atom.Atom) bool {
|
||||
for i := len(*ts) - 1; i >= 0; i-- {
|
||||
for _, tag := range tags {
|
||||
if (*ts)[i] == tag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ts *tagStack) push(tag atom.Atom) {
|
||||
*ts = append(*ts, tag)
|
||||
}
|
||||
|
||||
func (ts *tagStack) pop(tag atom.Atom) bool {
|
||||
if len(*ts) > 0 && (*ts)[len(*ts)-1] == tag {
|
||||
*ts = (*ts)[:len(*ts)-1]
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sanitizeAndLinkifyHTML(body string) (string, error) {
|
||||
tz := html.NewTokenizer(strings.NewReader(body))
|
||||
var built strings.Builder
|
||||
ts := make(tagStack, 2)
|
||||
Loop:
|
||||
for {
|
||||
switch tz.Next() {
|
||||
case html.ErrorToken:
|
||||
err := tz.Err()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break Loop
|
||||
}
|
||||
return "", err
|
||||
case html.StartTagToken, html.SelfClosingTagToken:
|
||||
token := tz.Token()
|
||||
if !tagIsAllowed(token.DataAtom) {
|
||||
continue
|
||||
}
|
||||
tagIsSelfClosing := isSelfClosing(token.DataAtom)
|
||||
if token.Type == html.SelfClosingTagToken && !tagIsSelfClosing {
|
||||
continue
|
||||
}
|
||||
switch token.DataAtom {
|
||||
case atom.A:
|
||||
writeA(&built, token.Attr)
|
||||
case atom.Img:
|
||||
writeImg(&built, token.Attr)
|
||||
case atom.Span, atom.Font:
|
||||
writeSpan(&built, token.Attr)
|
||||
default:
|
||||
built.WriteByte('<')
|
||||
built.WriteString(token.Data)
|
||||
for _, attr := range token.Attr {
|
||||
if attributeIsAllowed(token.DataAtom, attr) {
|
||||
writeAttribute(&built, attr.Key, attr.Val)
|
||||
}
|
||||
}
|
||||
}
|
||||
built.WriteByte('>')
|
||||
if !tagIsSelfClosing {
|
||||
ts.push(token.DataAtom)
|
||||
}
|
||||
case html.EndTagToken:
|
||||
tagName, _ := tz.TagName()
|
||||
tag := atom.Lookup(tagName)
|
||||
if tagIsAllowed(tag) && ts.pop(tag) {
|
||||
built.WriteString("</")
|
||||
built.Write(tagName)
|
||||
built.WriteByte('>')
|
||||
}
|
||||
case html.TextToken:
|
||||
if ts.contains(atom.Pre, atom.Code, atom.A) {
|
||||
writeEscapedBytes(&built, tz.Text())
|
||||
} else {
|
||||
linkifyAndWriteBytes(&built, tz.Text())
|
||||
}
|
||||
case html.DoctypeToken, html.CommentToken:
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
slices.Reverse(ts)
|
||||
for _, t := range ts {
|
||||
built.WriteString("</")
|
||||
built.WriteString(t.String())
|
||||
built.WriteByte('>')
|
||||
}
|
||||
return built.String(), nil
|
||||
}
|
179
pkg/hicli/json-commands.go
Normal file
179
pkg/hicli/json-commands.go
Normal file
|
@ -0,0 +1,179 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
)
|
||||
|
||||
func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any, error) {
|
||||
switch req.Command {
|
||||
case "get_state":
|
||||
return h.State(), nil
|
||||
case "cancel":
|
||||
return unmarshalAndCall(req.Data, func(params *cancelRequestParams) (bool, error) {
|
||||
h.jsonRequestsLock.Lock()
|
||||
cancelTarget, ok := h.jsonRequests[params.RequestID]
|
||||
h.jsonRequestsLock.Unlock()
|
||||
if ok {
|
||||
return false, nil
|
||||
}
|
||||
if params.Reason == "" {
|
||||
cancelTarget(nil)
|
||||
} else {
|
||||
cancelTarget(errors.New(params.Reason))
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
case "send_message":
|
||||
return unmarshalAndCall(req.Data, func(params *sendMessageParams) (*database.Event, error) {
|
||||
return h.SendMessage(ctx, params.RoomID, params.Text, params.MediaPath, params.ReplyTo, params.Mentions)
|
||||
})
|
||||
case "send_event":
|
||||
return unmarshalAndCall(req.Data, func(params *sendEventParams) (*database.Event, error) {
|
||||
return h.Send(ctx, params.RoomID, params.EventType, params.Content)
|
||||
})
|
||||
case "mark_read":
|
||||
return unmarshalAndCall(req.Data, func(params *markReadParams) (bool, error) {
|
||||
return true, h.MarkRead(ctx, params.RoomID, params.EventID, params.ReceiptType)
|
||||
})
|
||||
case "set_typing":
|
||||
return unmarshalAndCall(req.Data, func(params *setTypingParams) (bool, error) {
|
||||
return true, h.SetTyping(ctx, params.RoomID, time.Duration(params.Timeout)*time.Millisecond)
|
||||
})
|
||||
case "get_event":
|
||||
return unmarshalAndCall(req.Data, func(params *getEventParams) (*database.Event, error) {
|
||||
return h.GetEvent(ctx, params.RoomID, params.EventID)
|
||||
})
|
||||
case "get_events_by_rowids":
|
||||
return unmarshalAndCall(req.Data, func(params *getEventsByRowIDsParams) ([]*database.Event, error) {
|
||||
return h.GetEventsByRowIDs(ctx, params.RowIDs)
|
||||
})
|
||||
case "get_room_state":
|
||||
return unmarshalAndCall(req.Data, func(params *getRoomStateParams) ([]*database.Event, error) {
|
||||
return h.GetRoomState(ctx, params.RoomID, params.FetchMembers, params.Refetch)
|
||||
})
|
||||
case "paginate":
|
||||
return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) {
|
||||
return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit)
|
||||
})
|
||||
case "paginate_server":
|
||||
return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) {
|
||||
return h.PaginateServer(ctx, params.RoomID, params.Limit)
|
||||
})
|
||||
case "ensure_group_session_shared":
|
||||
return unmarshalAndCall(req.Data, func(params *ensureGroupSessionSharedParams) (bool, error) {
|
||||
return true, h.EnsureGroupSessionShared(ctx, params.RoomID)
|
||||
})
|
||||
case "login":
|
||||
return unmarshalAndCall(req.Data, func(params *loginParams) (bool, error) {
|
||||
return true, h.LoginPassword(ctx, params.HomeserverURL, params.Username, params.Password)
|
||||
})
|
||||
case "verify":
|
||||
return unmarshalAndCall(req.Data, func(params *verifyParams) (bool, error) {
|
||||
return true, h.VerifyWithRecoveryKey(ctx, params.RecoveryKey)
|
||||
})
|
||||
case "discover_homeserver":
|
||||
return unmarshalAndCall(req.Data, func(params *discoverHomeserverParams) (*mautrix.ClientWellKnown, error) {
|
||||
_, homeserver, err := params.UserID.Parse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mautrix.DiscoverClientAPI(ctx, homeserver)
|
||||
})
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown command %q", req.Command)
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalAndCall[T, O any](data json.RawMessage, fn func(*T) (O, error)) (output O, err error) {
|
||||
var input T
|
||||
err = json.Unmarshal(data, &input)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return fn(&input)
|
||||
}
|
||||
|
||||
type cancelRequestParams struct {
|
||||
RequestID int64 `json:"request_id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type sendMessageParams struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
Text string `json:"text"`
|
||||
MediaPath string `json:"media_path"`
|
||||
ReplyTo id.EventID `json:"reply_to"`
|
||||
Mentions *event.Mentions `json:"mentions"`
|
||||
}
|
||||
|
||||
type sendEventParams struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
EventType event.Type `json:"type"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
type markReadParams struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
EventID id.EventID `json:"event_id"`
|
||||
ReceiptType event.ReceiptType `json:"receipt_type"`
|
||||
}
|
||||
|
||||
type setTypingParams struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
type getEventParams struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
EventID id.EventID `json:"event_id"`
|
||||
}
|
||||
|
||||
type getEventsByRowIDsParams struct {
|
||||
RowIDs []database.EventRowID `json:"row_ids"`
|
||||
}
|
||||
|
||||
type getRoomStateParams struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
Refetch bool `json:"refetch"`
|
||||
FetchMembers bool `json:"fetch_members"`
|
||||
}
|
||||
|
||||
type ensureGroupSessionSharedParams struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
}
|
||||
|
||||
type loginParams struct {
|
||||
HomeserverURL string `json:"homeserver_url"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type verifyParams struct {
|
||||
RecoveryKey string `json:"recovery_key"`
|
||||
}
|
||||
|
||||
type discoverHomeserverParams struct {
|
||||
UserID id.UserID `json:"user_id"`
|
||||
}
|
||||
|
||||
type paginateParams struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
MaxTimelineID database.TimelineRowID `json:"max_timeline_id"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
119
pkg/hicli/json.go
Normal file
119
pkg/hicli/json.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"go.mau.fi/util/exerrors"
|
||||
)
|
||||
|
||||
type JSONCommand struct {
|
||||
Command string `json:"command"`
|
||||
RequestID int64 `json:"request_id"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
type JSONEventHandler func(*JSONCommand)
|
||||
|
||||
var outgoingEventCounter atomic.Int64
|
||||
|
||||
func (jeh JSONEventHandler) HandleEvent(evt any) {
|
||||
var command string
|
||||
switch evt.(type) {
|
||||
case *SyncComplete:
|
||||
command = "sync_complete"
|
||||
case *EventsDecrypted:
|
||||
command = "events_decrypted"
|
||||
case *Typing:
|
||||
command = "typing"
|
||||
case *SendComplete:
|
||||
command = "send_complete"
|
||||
case *ClientState:
|
||||
command = "client_state"
|
||||
default:
|
||||
panic(fmt.Errorf("unknown event type %T", evt))
|
||||
}
|
||||
data, err := json.Marshal(evt)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to marshal event %T: %w", evt, err))
|
||||
}
|
||||
jeh(&JSONCommand{
|
||||
Command: command,
|
||||
RequestID: -outgoingEventCounter.Add(1),
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HiClient) State() *ClientState {
|
||||
state := &ClientState{}
|
||||
if acc := h.Account; acc != nil {
|
||||
state.IsLoggedIn = true
|
||||
state.UserID = acc.UserID
|
||||
state.DeviceID = acc.DeviceID
|
||||
state.HomeserverURL = acc.HomeserverURL
|
||||
state.IsVerified = h.Verified
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func (h *HiClient) dispatchCurrentState() {
|
||||
h.EventHandler(h.State())
|
||||
}
|
||||
|
||||
func (h *HiClient) SubmitJSONCommand(ctx context.Context, req *JSONCommand) *JSONCommand {
|
||||
if req.Command == "ping" {
|
||||
return &JSONCommand{
|
||||
Command: "pong",
|
||||
RequestID: req.RequestID,
|
||||
}
|
||||
}
|
||||
log := h.Log.With().Int64("request_id", req.RequestID).Str("command", req.Command).Logger()
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
defer func() {
|
||||
cancel(nil)
|
||||
h.jsonRequestsLock.Lock()
|
||||
delete(h.jsonRequests, req.RequestID)
|
||||
h.jsonRequestsLock.Unlock()
|
||||
}()
|
||||
ctx = log.WithContext(ctx)
|
||||
h.jsonRequestsLock.Lock()
|
||||
h.jsonRequests[req.RequestID] = cancel
|
||||
h.jsonRequestsLock.Unlock()
|
||||
resp, err := h.handleJSONCommand(ctx, req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
causeErr := context.Cause(ctx)
|
||||
if causeErr != ctx.Err() {
|
||||
err = fmt.Errorf("%w: %w", err, causeErr)
|
||||
}
|
||||
}
|
||||
return &JSONCommand{
|
||||
Command: "error",
|
||||
RequestID: req.RequestID,
|
||||
Data: exerrors.Must(json.Marshal(err.Error())),
|
||||
}
|
||||
}
|
||||
var respData json.RawMessage
|
||||
respData, err = json.Marshal(resp)
|
||||
if err != nil {
|
||||
return &JSONCommand{
|
||||
Command: "error",
|
||||
RequestID: req.RequestID,
|
||||
Data: exerrors.Must(json.Marshal(fmt.Sprintf("failed to marshal response json: %v", err))),
|
||||
}
|
||||
}
|
||||
return &JSONCommand{
|
||||
Command: "response",
|
||||
RequestID: req.RequestID,
|
||||
Data: respData,
|
||||
}
|
||||
}
|
88
pkg/hicli/login.go
Normal file
88
pkg/hicli/login.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
)
|
||||
|
||||
var InitialDeviceDisplayName = "mautrix hiclient"
|
||||
|
||||
func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error {
|
||||
var err error
|
||||
h.Client.HomeserverURL, err = url.Parse(homeserverURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.Login(ctx, &mautrix.ReqLogin{
|
||||
Type: mautrix.AuthTypePassword,
|
||||
Identifier: mautrix.UserIdentifier{
|
||||
Type: mautrix.IdentifierTypeUser,
|
||||
User: username,
|
||||
},
|
||||
Password: password,
|
||||
InitialDeviceDisplayName: InitialDeviceDisplayName,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error {
|
||||
err := h.CheckServerVersions(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.StoreCredentials = true
|
||||
req.StoreHomeserverURL = true
|
||||
resp, err := h.Client.Login(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer h.dispatchCurrentState()
|
||||
h.Account = &database.Account{
|
||||
UserID: resp.UserID,
|
||||
DeviceID: resp.DeviceID,
|
||||
AccessToken: resp.AccessToken,
|
||||
HomeserverURL: h.Client.HomeserverURL.String(),
|
||||
}
|
||||
h.CryptoStore.AccountID = resp.UserID.String()
|
||||
h.CryptoStore.DeviceID = resp.DeviceID
|
||||
err = h.DB.Account.Put(ctx, h.Account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.Crypto.Load(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load olm machine: %w", err)
|
||||
}
|
||||
err = h.Crypto.ShareKeys(ctx, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = h.Crypto.FetchKeys(ctx, []id.UserID{h.Account.UserID}, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch own devices: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryKey string) error {
|
||||
err := h.LoginPassword(ctx, homeserverURL, username, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.VerifyWithRecoveryKey(ctx, recoveryKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
245
pkg/hicli/paginate.go
Normal file
245
pkg/hicli/paginate.go
Normal file
|
@ -0,0 +1,245 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
)
|
||||
|
||||
var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress")
|
||||
|
||||
func (h *HiClient) GetEventsByRowIDs(ctx context.Context, rowIDs []database.EventRowID) ([]*database.Event, error) {
|
||||
events, err := h.DB.Event.GetByRowIDs(ctx, rowIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(events) == 0 {
|
||||
return events, nil
|
||||
}
|
||||
firstRoomID := events[0].RoomID
|
||||
allInSameRoom := true
|
||||
for _, evt := range events {
|
||||
h.ReprocessExistingEvent(ctx, evt)
|
||||
if evt.RoomID != firstRoomID {
|
||||
allInSameRoom = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allInSameRoom {
|
||||
err = h.DB.Event.FillLastEditRowIDs(ctx, firstRoomID, events)
|
||||
if err != nil {
|
||||
return events, fmt.Errorf("failed to fill last edit row IDs: %w", err)
|
||||
}
|
||||
err = h.DB.Event.FillReactionCounts(ctx, firstRoomID, events)
|
||||
if err != nil {
|
||||
return events, fmt.Errorf("failed to fill reaction counts: %w", err)
|
||||
}
|
||||
} else {
|
||||
// TODO slow path where events are collected and filling is done one room at a time?
|
||||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) {
|
||||
if evt, err := h.DB.Event.GetByID(ctx, eventID); err != nil {
|
||||
return nil, fmt.Errorf("failed to get event from database: %w", err)
|
||||
} else if evt != nil {
|
||||
h.ReprocessExistingEvent(ctx, evt)
|
||||
return evt, nil
|
||||
} else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil {
|
||||
return nil, fmt.Errorf("failed to get event from server: %w", err)
|
||||
} else {
|
||||
return h.processEvent(ctx, serverEvt, nil, nil, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) {
|
||||
var evts []*event.Event
|
||||
if refetch {
|
||||
resp, err := h.Client.StateAsArray(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refetch state: %w", err)
|
||||
}
|
||||
evts = resp
|
||||
} else if fetchMembers {
|
||||
resp, err := h.Client.Members(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch members: %w", err)
|
||||
}
|
||||
evts = resp.Chunk
|
||||
}
|
||||
if evts != nil {
|
||||
err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
room, err := h.DB.Room.Get(ctx, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get room from database: %w", err)
|
||||
}
|
||||
updatedRoom := &database.Room{
|
||||
ID: room.ID,
|
||||
HasMemberList: true,
|
||||
}
|
||||
entries := make([]*database.CurrentStateEntry, len(evts))
|
||||
for i, evt := range evts {
|
||||
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process event %s: %w", evt.ID, err)
|
||||
}
|
||||
entries[i] = &database.CurrentStateEntry{
|
||||
EventType: evt.Type,
|
||||
StateKey: *evt.StateKey,
|
||||
EventRowID: dbEvt.RowID,
|
||||
}
|
||||
if evt.Type == event.StateMember {
|
||||
entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string))
|
||||
} else {
|
||||
processImportantEvent(ctx, evt, room, updatedRoom)
|
||||
}
|
||||
}
|
||||
err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
|
||||
if roomChanged {
|
||||
err = h.DB.Room.Upsert(ctx, updatedRoom)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save room data: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return h.DB.CurrentState.GetAll(ctx, roomID)
|
||||
}
|
||||
|
||||
type PaginationResponse struct {
|
||||
Events []*database.Event `json:"events"`
|
||||
HasMore bool `json:"has_more"`
|
||||
}
|
||||
|
||||
func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) (*PaginationResponse, error) {
|
||||
evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(evts) > 0 {
|
||||
for _, evt := range evts {
|
||||
h.ReprocessExistingEvent(ctx, evt)
|
||||
}
|
||||
return &PaginationResponse{Events: evts, HasMore: true}, nil
|
||||
} else {
|
||||
return h.PaginateServer(ctx, roomID, limit)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) (*PaginationResponse, error) {
|
||||
ctx, cancel := context.WithCancelCause(ctx)
|
||||
h.paginationInterrupterLock.Lock()
|
||||
if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating {
|
||||
h.paginationInterrupterLock.Unlock()
|
||||
return nil, ErrPaginationAlreadyInProgress
|
||||
}
|
||||
h.paginationInterrupter[roomID] = cancel
|
||||
h.paginationInterrupterLock.Unlock()
|
||||
defer func() {
|
||||
h.paginationInterrupterLock.Lock()
|
||||
delete(h.paginationInterrupter, roomID)
|
||||
h.paginationInterrupterLock.Unlock()
|
||||
}()
|
||||
|
||||
room, err := h.DB.Room.Get(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get room from database: %w", err)
|
||||
} else if room.PrevBatch == database.PrevBatchPaginationComplete {
|
||||
return &PaginationResponse{Events: []*database.Event{}, HasMore: false}, nil
|
||||
}
|
||||
resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get messages from server: %w", err)
|
||||
}
|
||||
events := make([]*database.Event, len(resp.Chunk))
|
||||
if resp.End == "" {
|
||||
resp.End = database.PrevBatchPaginationComplete
|
||||
}
|
||||
if resp.End == database.PrevBatchPaginationComplete || len(resp.Chunk) == 0 {
|
||||
err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set prev_batch: %w", err)
|
||||
}
|
||||
return &PaginationResponse{Events: events, HasMore: resp.End != ""}, nil
|
||||
}
|
||||
wakeupSessionRequests := false
|
||||
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
if err = ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
eventRowIDs := make([]database.EventRowID, len(resp.Chunk))
|
||||
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
|
||||
iOffset := 0
|
||||
for i, evt := range resp.Chunk {
|
||||
dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil {
|
||||
return fmt.Errorf("failed to check if event exists in timeline: %w", err)
|
||||
} else if exists {
|
||||
zerolog.Ctx(ctx).Warn().
|
||||
Int64("row_id", int64(dbEvt.RowID)).
|
||||
Str("event_id", dbEvt.ID.String()).
|
||||
Msg("Event already exists in timeline, skipping")
|
||||
iOffset++
|
||||
continue
|
||||
}
|
||||
events[i-iOffset] = dbEvt
|
||||
eventRowIDs[i-iOffset] = events[i-iOffset].RowID
|
||||
}
|
||||
if iOffset >= len(events) {
|
||||
events = events[:0]
|
||||
return nil
|
||||
}
|
||||
events = events[:len(events)-iOffset]
|
||||
eventRowIDs = eventRowIDs[:len(eventRowIDs)-iOffset]
|
||||
wakeupSessionRequests = len(decryptionQueue) > 0
|
||||
for _, entry := range decryptionQueue {
|
||||
err = h.DB.SessionRequest.Put(ctx, entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
|
||||
}
|
||||
}
|
||||
err = h.DB.Event.FillLastEditRowIDs(ctx, roomID, events)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fill last edit row IDs: %w", err)
|
||||
}
|
||||
err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set prev_batch: %w", err)
|
||||
}
|
||||
var tuples []database.TimelineRowTuple
|
||||
tuples, err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepend events to timeline: %w", err)
|
||||
}
|
||||
for i, evt := range events {
|
||||
evt.TimelineRowID = tuples[i].Timeline
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err == nil && wakeupSessionRequests {
|
||||
h.WakeupRequestQueue()
|
||||
}
|
||||
return &PaginationResponse{Events: events, HasMore: true}, err
|
||||
}
|
80
pkg/hicli/pushrules.go
Normal file
80
pkg/hicli/pushrules.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/pushrules"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
)
|
||||
|
||||
type pushRoom struct {
|
||||
ctx context.Context
|
||||
roomID id.RoomID
|
||||
h *HiClient
|
||||
ll *mautrix.LazyLoadSummary
|
||||
}
|
||||
|
||||
func (p *pushRoom) GetOwnDisplayname() string {
|
||||
// TODO implement
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *pushRoom) GetMemberCount() int {
|
||||
if p.ll == nil {
|
||||
room, err := p.h.DB.Room.Get(p.ctx, p.roomID)
|
||||
if err != nil {
|
||||
zerolog.Ctx(p.ctx).Err(err).
|
||||
Stringer("room_id", p.roomID).
|
||||
Msg("Failed to get room by ID in push rule evaluator")
|
||||
} else if room != nil {
|
||||
p.ll = room.LazyLoadSummary
|
||||
}
|
||||
}
|
||||
if p.ll != nil && p.ll.JoinedMemberCount != nil {
|
||||
return *p.ll.JoinedMemberCount
|
||||
}
|
||||
// TODO query db?
|
||||
return 0
|
||||
}
|
||||
|
||||
func (p *pushRoom) GetEvent(id id.EventID) *event.Event {
|
||||
evt, err := p.h.DB.Event.GetByID(p.ctx, id)
|
||||
if err != nil {
|
||||
zerolog.Ctx(p.ctx).Err(err).
|
||||
Stringer("event_id", id).
|
||||
Msg("Failed to get event by ID in push rule evaluator")
|
||||
}
|
||||
return evt.AsRawMautrix()
|
||||
}
|
||||
|
||||
var _ pushrules.EventfulRoom = (*pushRoom)(nil)
|
||||
|
||||
func (h *HiClient) evaluatePushRules(ctx context.Context, llSummary *mautrix.LazyLoadSummary, baseType database.UnreadType, evt *event.Event) database.UnreadType {
|
||||
should := h.PushRules.Load().GetMatchingRule(&pushRoom{
|
||||
ctx: ctx,
|
||||
roomID: evt.RoomID,
|
||||
h: h,
|
||||
ll: llSummary,
|
||||
}, evt).GetActions().Should()
|
||||
if should.Notify {
|
||||
baseType |= database.UnreadTypeNotify
|
||||
}
|
||||
if should.Highlight {
|
||||
baseType |= database.UnreadTypeHighlight
|
||||
}
|
||||
if should.PlaySound {
|
||||
baseType |= database.UnreadTypeSound
|
||||
}
|
||||
return baseType
|
||||
}
|
287
pkg/hicli/send.go
Normal file
287
pkg/hicli/send.go
Normal file
|
@ -0,0 +1,287 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yuin/goldmark"
|
||||
"go.mau.fi/util/jsontime"
|
||||
"go.mau.fi/util/ptr"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/format"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
"go.mau.fi/gomuks/pkg/rainbow"
|
||||
)
|
||||
|
||||
var (
|
||||
rainbowWithHTML = goldmark.New(format.Extensions, format.HTMLOptions, goldmark.WithExtensions(rainbow.Extension))
|
||||
)
|
||||
|
||||
func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, mediaPath string, replyTo id.EventID, mentions *event.Mentions) (*database.Event, error) {
|
||||
var content event.MessageEventContent
|
||||
if strings.HasPrefix(text, "/rainbow ") {
|
||||
text = strings.TrimPrefix(text, "/rainbow ")
|
||||
content = format.RenderMarkdownCustom(text, rainbowWithHTML)
|
||||
content.FormattedBody = rainbow.ApplyColor(content.FormattedBody)
|
||||
} else if strings.HasPrefix(text, "/plain ") {
|
||||
text = strings.TrimPrefix(text, "/plain ")
|
||||
content = format.RenderMarkdown(text, false, false)
|
||||
} else if strings.HasPrefix(text, "/html ") {
|
||||
text = strings.TrimPrefix(text, "/html ")
|
||||
content = format.RenderMarkdown(text, false, true)
|
||||
} else {
|
||||
content = format.RenderMarkdown(text, true, false)
|
||||
}
|
||||
if mentions != nil {
|
||||
content.Mentions.Room = mentions.Room
|
||||
for _, userID := range mentions.UserIDs {
|
||||
if userID != h.Account.UserID {
|
||||
content.Mentions.Add(userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
if replyTo != "" {
|
||||
content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(replyTo)
|
||||
}
|
||||
return h.Send(ctx, roomID, event.EventMessage, &content)
|
||||
}
|
||||
|
||||
func (h *HiClient) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType) error {
|
||||
content := &mautrix.ReqSetReadMarkers{
|
||||
FullyRead: eventID,
|
||||
}
|
||||
if receiptType == event.ReceiptTypeRead {
|
||||
content.Read = eventID
|
||||
} else if receiptType == event.ReceiptTypeReadPrivate {
|
||||
content.ReadPrivate = eventID
|
||||
} else {
|
||||
return fmt.Errorf("invalid receipt type: %v", receiptType)
|
||||
}
|
||||
err := h.Client.SetReadMarkers(ctx, roomID, content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark event as read: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) SetTyping(ctx context.Context, roomID id.RoomID, timeout time.Duration) error {
|
||||
_, err := h.Client.UserTyping(ctx, roomID, timeout > 0, timeout)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*database.Event, error) {
|
||||
roomMeta, err := h.DB.Room.Get(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get room metadata: %w", err)
|
||||
} else if roomMeta == nil {
|
||||
return nil, fmt.Errorf("unknown room")
|
||||
}
|
||||
var decryptedType event.Type
|
||||
var decryptedContent json.RawMessage
|
||||
var megolmSessionID id.SessionID
|
||||
if roomMeta.EncryptionEvent != nil && evtType != event.EventReaction {
|
||||
decryptedType = evtType
|
||||
decryptedContent, err = json.Marshal(content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal event content: %w", err)
|
||||
}
|
||||
encryptedContent, err := h.Encrypt(ctx, roomMeta, evtType, content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt event: %w", err)
|
||||
}
|
||||
megolmSessionID = encryptedContent.SessionID
|
||||
content = encryptedContent
|
||||
evtType = event.EventEncrypted
|
||||
}
|
||||
mainContent, err := json.Marshal(content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal event content: %w", err)
|
||||
}
|
||||
txnID := "hicli-" + h.Client.TxnID()
|
||||
relatesTo, relationType := database.GetRelatesToFromBytes(mainContent)
|
||||
dbEvt := &database.Event{
|
||||
RoomID: roomID,
|
||||
ID: id.EventID(fmt.Sprintf("~%s", txnID)),
|
||||
Sender: h.Account.UserID,
|
||||
Type: evtType.Type,
|
||||
Timestamp: jsontime.UnixMilliNow(),
|
||||
Content: mainContent,
|
||||
Decrypted: decryptedContent,
|
||||
DecryptedType: decryptedType.Type,
|
||||
Unsigned: []byte("{}"),
|
||||
TransactionID: txnID,
|
||||
RelatesTo: relatesTo,
|
||||
RelationType: relationType,
|
||||
MegolmSessionID: megolmSessionID,
|
||||
DecryptionError: "",
|
||||
SendError: "not sent",
|
||||
Reactions: map[string]int{},
|
||||
LastEditRowID: ptr.Ptr(database.EventRowID(0)),
|
||||
}
|
||||
_, err = h.DB.Event.Insert(ctx, dbEvt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to insert event into database: %w", err)
|
||||
}
|
||||
ctx = context.WithoutCancel(ctx)
|
||||
go func() {
|
||||
err := h.SetTyping(ctx, roomID, 0)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to stop typing while sending message")
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
var err error
|
||||
defer func() {
|
||||
h.EventHandler(&SendComplete{
|
||||
Event: dbEvt,
|
||||
Error: err,
|
||||
})
|
||||
}()
|
||||
var resp *mautrix.RespSendEvent
|
||||
resp, err = h.Client.SendMessageEvent(ctx, roomID, evtType, content, mautrix.ReqSendEvent{
|
||||
Timestamp: dbEvt.Timestamp.UnixMilli(),
|
||||
TransactionID: txnID,
|
||||
DontEncrypt: true,
|
||||
})
|
||||
if err != nil {
|
||||
dbEvt.SendError = err.Error()
|
||||
err = fmt.Errorf("failed to send event: %w", err)
|
||||
err2 := h.DB.Event.UpdateSendError(ctx, dbEvt.RowID, dbEvt.SendError)
|
||||
if err2 != nil {
|
||||
zerolog.Ctx(ctx).Err(err2).AnErr("send_error", err).
|
||||
Msg("Failed to update send error in database after sending failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
dbEvt.ID = resp.EventID
|
||||
err = h.DB.Event.UpdateID(ctx, dbEvt.RowID, dbEvt.ID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to update event ID in database: %w", err)
|
||||
}
|
||||
}()
|
||||
return dbEvt, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) Encrypt(ctx context.Context, room *database.Room, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
|
||||
h.encryptLock.Lock()
|
||||
defer h.encryptLock.Unlock()
|
||||
encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content)
|
||||
if errors.Is(err, crypto.SessionExpired) || errors.Is(err, crypto.NoGroupSession) || errors.Is(err, crypto.SessionNotShared) {
|
||||
if err = h.shareGroupSession(ctx, room); err != nil {
|
||||
err = fmt.Errorf("failed to share group session: %w", err)
|
||||
} else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content); err != nil {
|
||||
err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (h *HiClient) EnsureGroupSessionShared(ctx context.Context, roomID id.RoomID) error {
|
||||
h.encryptLock.Lock()
|
||||
defer h.encryptLock.Unlock()
|
||||
if session, err := h.CryptoStore.GetOutboundGroupSession(ctx, roomID); err != nil {
|
||||
return fmt.Errorf("failed to get previous outbound group session: %w", err)
|
||||
} else if session != nil && session.Shared && !session.Expired() {
|
||||
return nil
|
||||
} else if roomMeta, err := h.DB.Room.Get(ctx, roomID); err != nil {
|
||||
return fmt.Errorf("failed to get room metadata: %w", err)
|
||||
} else if roomMeta == nil {
|
||||
return fmt.Errorf("unknown room")
|
||||
} else {
|
||||
return h.shareGroupSession(ctx, roomMeta)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error {
|
||||
if room.HasMemberList {
|
||||
return nil
|
||||
}
|
||||
resp, err := h.Client.Members(ctx, room.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get room member list: %w", err)
|
||||
}
|
||||
err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
entries := make([]*database.CurrentStateEntry, len(resp.Chunk))
|
||||
for i, evt := range resp.Chunk {
|
||||
dbEvt, err := h.processEvent(ctx, evt, nil, nil, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entries[i] = &database.CurrentStateEntry{
|
||||
EventType: evt.Type,
|
||||
StateKey: *evt.StateKey,
|
||||
EventRowID: dbEvt.RowID,
|
||||
Membership: event.Membership(evt.Content.Raw["membership"].(string)),
|
||||
}
|
||||
}
|
||||
err := h.DB.CurrentState.AddMany(ctx, room.ID, false, entries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.DB.Room.Upsert(ctx, &database.Room{
|
||||
ID: room.ID,
|
||||
HasMemberList: true,
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process room member list: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) shareGroupSession(ctx context.Context, room *database.Room) error {
|
||||
err := h.loadMembers(ctx, room)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
shareToInvited := h.shouldShareKeysToInvitedUsers(ctx, room.ID)
|
||||
var users []id.UserID
|
||||
if shareToInvited {
|
||||
users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, room.ID)
|
||||
} else {
|
||||
users, err = h.ClientStore.GetRoomJoinedMembers(ctx, room.ID)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get room member list: %w", err)
|
||||
} else if err = h.Crypto.ShareGroupSession(ctx, room.ID, users); err != nil {
|
||||
return fmt.Errorf("failed to share group session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) shouldShareKeysToInvitedUsers(ctx context.Context, roomID id.RoomID) bool {
|
||||
historyVisibility, err := h.DB.CurrentState.Get(ctx, roomID, event.StateHistoryVisibility, "")
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get history visibility event")
|
||||
return false
|
||||
}
|
||||
mautrixEvt := historyVisibility.AsRawMautrix()
|
||||
err = mautrixEvt.Content.ParseRaw(mautrixEvt.Type)
|
||||
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to parse history visibility event")
|
||||
return false
|
||||
}
|
||||
hv, ok := mautrixEvt.Content.Parsed.(*event.HistoryVisibilityEventContent)
|
||||
if !ok {
|
||||
zerolog.Ctx(ctx).Warn().Msg("Unexpected parsed content type for history visibility event")
|
||||
return false
|
||||
}
|
||||
return hv.HistoryVisibility == event.HistoryVisibilityInvited ||
|
||||
hv.HistoryVisibility == event.HistoryVisibilityShared ||
|
||||
hv.HistoryVisibility == event.HistoryVisibilityWorldReadable
|
||||
}
|
850
pkg/hicli/sync.go
Normal file
850
pkg/hicli/sync.go
Normal file
|
@ -0,0 +1,850 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.mau.fi/util/exzerolog"
|
||||
"go.mau.fi/util/jsontime"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/pushrules"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
)
|
||||
|
||||
type syncContext struct {
|
||||
shouldWakeupRequestQueue bool
|
||||
|
||||
evt *SyncComplete
|
||||
}
|
||||
|
||||
func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
||||
log := zerolog.Ctx(ctx)
|
||||
postponedToDevices := resp.ToDevice.Events[:0]
|
||||
for _, evt := range resp.ToDevice.Events {
|
||||
evt.Type.Class = event.ToDeviceEventType
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
||||
log.Warn().Err(err).
|
||||
Stringer("event_type", &evt.Type).
|
||||
Stringer("sender", evt.Sender).
|
||||
Msg("Failed to parse to-device event, skipping")
|
||||
continue
|
||||
}
|
||||
|
||||
switch content := evt.Content.Parsed.(type) {
|
||||
case *event.EncryptedEventContent:
|
||||
h.Crypto.HandleEncryptedEvent(ctx, evt)
|
||||
case *event.RoomKeyWithheldEventContent:
|
||||
h.Crypto.HandleRoomKeyWithheld(ctx, content)
|
||||
default:
|
||||
postponedToDevices = append(postponedToDevices, evt)
|
||||
}
|
||||
}
|
||||
resp.ToDevice.Events = postponedToDevices
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
|
||||
h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
|
||||
go h.asyncPostProcessSyncResponse(ctx, resp, since)
|
||||
syncCtx := ctx.Value(syncContextKey).(*syncContext)
|
||||
if syncCtx.shouldWakeupRequestQueue {
|
||||
h.WakeupRequestQueue()
|
||||
}
|
||||
h.firstSyncReceived = true
|
||||
if !syncCtx.evt.IsEmpty() {
|
||||
h.EventHandler(syncCtx.evt)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) {
|
||||
for _, evt := range resp.ToDevice.Events {
|
||||
switch content := evt.Content.Parsed.(type) {
|
||||
case *event.SecretRequestEventContent:
|
||||
h.Crypto.HandleSecretRequest(ctx, evt.Sender, content)
|
||||
case *event.RoomKeyRequestEventContent:
|
||||
h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
||||
if len(resp.DeviceLists.Changed) > 0 {
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)).
|
||||
Msg("Marking changed device lists for tracked users as outdated")
|
||||
err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark changed device lists as outdated: %w", err)
|
||||
}
|
||||
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
|
||||
}
|
||||
|
||||
for _, evt := range resp.AccountData.Events {
|
||||
evt.Type.Class = event.AccountDataEventType
|
||||
err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
|
||||
}
|
||||
if evt.Type == event.AccountDataPushRules {
|
||||
err = evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
||||
zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to parse push rules in sync")
|
||||
} else if pushRules, ok := evt.Content.Parsed.(*pushrules.EventContent); ok {
|
||||
h.PushRules.Store(pushRules.Ruleset)
|
||||
zerolog.Ctx(ctx).Debug().Msg("Updated push rules from sync")
|
||||
}
|
||||
}
|
||||
}
|
||||
for roomID, room := range resp.Rooms.Join {
|
||||
err := h.processSyncJoinedRoom(ctx, roomID, room)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process joined room %s: %w", roomID, err)
|
||||
}
|
||||
}
|
||||
for roomID, room := range resp.Rooms.Leave {
|
||||
err := h.processSyncLeftRoom(ctx, roomID, room)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process left room %s: %w", roomID, err)
|
||||
}
|
||||
}
|
||||
h.Account.NextBatch = resp.NextBatch
|
||||
err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save next_batch: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) {
|
||||
receiptList := make([]*database.Receipt, 0)
|
||||
var newOwnReceipts []id.EventID
|
||||
for eventID, receipts := range *content {
|
||||
for receiptType, users := range receipts {
|
||||
for userID, receiptInfo := range users {
|
||||
if userID == h.Account.UserID {
|
||||
newOwnReceipts = append(newOwnReceipts, eventID)
|
||||
}
|
||||
receiptList = append(receiptList, &database.Receipt{
|
||||
UserID: userID,
|
||||
ReceiptType: receiptType,
|
||||
ThreadID: receiptInfo.ThreadID,
|
||||
EventID: eventID,
|
||||
Timestamp: jsontime.UM(receiptInfo.Timestamp),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return receiptList, newOwnReceipts
|
||||
}
|
||||
|
||||
type receiptsToSave struct {
|
||||
roomID id.RoomID
|
||||
receipts []*database.Receipt
|
||||
}
|
||||
|
||||
func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error {
|
||||
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get room data: %w", err)
|
||||
} else if existingRoomData == nil {
|
||||
err = h.DB.Room.CreateRow(ctx, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to ensure room row exists: %w", err)
|
||||
}
|
||||
existingRoomData = &database.Room{ID: roomID, SortingTimestamp: jsontime.UnixMilliNow()}
|
||||
}
|
||||
|
||||
for _, evt := range room.AccountData.Events {
|
||||
evt.Type.Class = event.AccountDataEventType
|
||||
evt.RoomID = roomID
|
||||
err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err)
|
||||
}
|
||||
}
|
||||
var receipts []receiptsToSave
|
||||
var newOwnReceipts []id.EventID
|
||||
for _, evt := range room.Ephemeral.Events {
|
||||
evt.Type.Class = event.EphemeralEventType
|
||||
err = evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
||||
zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content")
|
||||
continue
|
||||
}
|
||||
switch evt.Type {
|
||||
case event.EphemeralEventReceipt:
|
||||
var receiptsList []*database.Receipt
|
||||
receiptsList, newOwnReceipts = h.receiptsToList(evt.Content.AsReceipt())
|
||||
receipts = append(receipts, receiptsToSave{roomID, receiptsList})
|
||||
case event.EphemeralEventTyping:
|
||||
go h.EventHandler(&Typing{
|
||||
RoomID: roomID,
|
||||
TypingEventContent: *evt.Content.AsTyping(),
|
||||
})
|
||||
}
|
||||
}
|
||||
err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, newOwnReceipts, room.UnreadNotifications)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rs := range receipts {
|
||||
err = h.DB.Receipt.PutMany(ctx, rs.roomID, rs.receipts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save receipts: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error {
|
||||
existingRoomData, err := h.DB.Room.Get(ctx, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get room data: %w", err)
|
||||
} else if existingRoomData == nil {
|
||||
return nil
|
||||
}
|
||||
// TODO delete room instead of processing?
|
||||
return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil)
|
||||
}
|
||||
|
||||
func isDecryptionErrorRetryable(err error) bool {
|
||||
return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld)
|
||||
}
|
||||
|
||||
func removeReplyFallback(evt *event.Event) []byte {
|
||||
if evt.Type != event.EventMessage && evt.Type != event.EventSticker {
|
||||
return nil
|
||||
}
|
||||
_ = evt.Content.ParseRaw(evt.Type)
|
||||
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
||||
if ok && content.RelatesTo.GetReplyTo() != "" {
|
||||
prevFormattedBody := content.FormattedBody
|
||||
content.RemoveReplyFallback()
|
||||
if content.FormattedBody != prevFormattedBody {
|
||||
bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody)
|
||||
bytes, err2 := sjson.SetBytes(bytes, "body", content.Body)
|
||||
if err == nil && err2 == nil {
|
||||
return bytes
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) {
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
withoutFallback := removeReplyFallback(decrypted)
|
||||
if withoutFallback != nil {
|
||||
return decrypted, withoutFallback, decrypted.Type.Type, nil
|
||||
}
|
||||
return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) addMediaCache(
|
||||
ctx context.Context,
|
||||
eventRowID database.EventRowID,
|
||||
uri id.ContentURIString,
|
||||
file *event.EncryptedFileInfo,
|
||||
info *event.FileInfo,
|
||||
fileName string,
|
||||
) {
|
||||
parsedMXC := uri.ParseOrIgnore()
|
||||
if !parsedMXC.IsValid() {
|
||||
return
|
||||
}
|
||||
cm := &database.CachedMedia{
|
||||
MXC: parsedMXC,
|
||||
EventRowID: eventRowID,
|
||||
FileName: fileName,
|
||||
}
|
||||
if file != nil {
|
||||
cm.EncFile = &file.EncryptedFile
|
||||
}
|
||||
if info != nil {
|
||||
cm.MimeType = info.MimeType
|
||||
}
|
||||
err := h.DB.CachedMedia.Put(ctx, cm)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Warn().Err(err).
|
||||
Stringer("mxc", parsedMXC).
|
||||
Int64("event_rowid", int64(eventRowID)).
|
||||
Msg("Failed to add cached media entry")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) {
|
||||
switch evt.Type {
|
||||
case event.EventMessage, event.EventSticker:
|
||||
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if content.File != nil {
|
||||
h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName())
|
||||
} else if content.URL != "" {
|
||||
h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName())
|
||||
}
|
||||
if content.GetInfo().ThumbnailFile != nil {
|
||||
h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "")
|
||||
} else if content.GetInfo().ThumbnailURL != "" {
|
||||
h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "")
|
||||
}
|
||||
case event.StateRoomAvatar:
|
||||
_ = evt.Content.ParseRaw(evt.Type)
|
||||
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
h.addMediaCache(ctx, rowID, content.URL, nil, nil, "")
|
||||
case event.StateMember:
|
||||
_ = evt.Content.ParseRaw(evt.Type)
|
||||
content, ok := evt.Content.Parsed.(*event.MemberEventContent)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent {
|
||||
if evt.Type != event.EventMessage && evt.Type != event.EventSticker {
|
||||
return nil
|
||||
}
|
||||
_ = evt.Content.ParseRaw(evt.Type)
|
||||
content, ok := evt.Content.Parsed.(*event.MessageEventContent)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if dbEvt.RelationType == event.RelReplace && content.NewContent != nil {
|
||||
content = content.NewContent
|
||||
}
|
||||
if content != nil {
|
||||
var sanitizedHTML string
|
||||
if content.Format == event.FormatHTML {
|
||||
sanitizedHTML, _ = sanitizeAndLinkifyHTML(content.FormattedBody)
|
||||
} else {
|
||||
var builder strings.Builder
|
||||
linkifyAndWriteBytes(&builder, []byte(content.Body))
|
||||
sanitizedHTML = builder.String()
|
||||
}
|
||||
return &database.LocalContent{SanitizedHTML: sanitizedHTML, HTMLVersion: CurrentHTMLSanitizerVersion}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const CurrentHTMLSanitizerVersion = 1
|
||||
|
||||
func (h *HiClient) ReprocessExistingEvent(ctx context.Context, evt *database.Event) {
|
||||
if evt.Type != event.EventMessage.Type || evt.LocalContent == nil || evt.LocalContent.HTMLVersion >= CurrentHTMLSanitizerVersion {
|
||||
return
|
||||
}
|
||||
evt.LocalContent = h.calculateLocalContent(ctx, evt, evt.AsRawMautrix())
|
||||
err := h.DB.Event.UpdateLocalContent(ctx, evt)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Stringer("event_id", evt.ID).
|
||||
Msg("Failed to update local content")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) {
|
||||
if dbEvt.RowID != 0 {
|
||||
h.cacheMedia(ctx, evt, dbEvt.RowID)
|
||||
}
|
||||
if evt.Sender != h.Account.UserID {
|
||||
dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt)
|
||||
}
|
||||
dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt)
|
||||
}
|
||||
|
||||
func (h *HiClient) processEvent(
|
||||
ctx context.Context,
|
||||
evt *event.Event,
|
||||
llSummary *mautrix.LazyLoadSummary,
|
||||
decryptionQueue map[id.SessionID]*database.SessionRequest,
|
||||
checkDB bool,
|
||||
) (*database.Event, error) {
|
||||
if checkDB {
|
||||
dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if event %s exists: %w", evt.ID, err)
|
||||
} else if dbEvt != nil {
|
||||
return dbEvt, nil
|
||||
}
|
||||
}
|
||||
dbEvt := database.MautrixToEvent(evt)
|
||||
contentWithoutFallback := removeReplyFallback(evt)
|
||||
if contentWithoutFallback != nil {
|
||||
dbEvt.Content = contentWithoutFallback
|
||||
}
|
||||
var decryptionErr error
|
||||
var decryptedMautrixEvt *event.Event
|
||||
if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" {
|
||||
decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt)
|
||||
if decryptionErr != nil {
|
||||
dbEvt.DecryptionError = decryptionErr.Error()
|
||||
}
|
||||
} else if evt.Type == event.EventRedaction {
|
||||
if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() {
|
||||
var err error
|
||||
evt.Content.VeryRaw, err = sjson.SetBytes(evt.Content.VeryRaw, "redacts", evt.Redacts)
|
||||
if err != nil {
|
||||
return dbEvt, fmt.Errorf("failed to set redacts field: %w", err)
|
||||
}
|
||||
} else if evt.Redacts == "" {
|
||||
evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str)
|
||||
}
|
||||
}
|
||||
if decryptedMautrixEvt != nil {
|
||||
h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt)
|
||||
} else {
|
||||
h.postDecryptProcess(ctx, llSummary, dbEvt, evt)
|
||||
}
|
||||
_, err := h.DB.Event.Upsert(ctx, dbEvt)
|
||||
if err != nil {
|
||||
return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err)
|
||||
}
|
||||
if decryptedMautrixEvt != nil {
|
||||
h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID)
|
||||
} else {
|
||||
h.cacheMedia(ctx, evt, dbEvt.RowID)
|
||||
}
|
||||
if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) {
|
||||
req, ok := decryptionQueue[dbEvt.MegolmSessionID]
|
||||
if !ok {
|
||||
req = &database.SessionRequest{
|
||||
RoomID: evt.RoomID,
|
||||
SessionID: dbEvt.MegolmSessionID,
|
||||
Sender: evt.Sender,
|
||||
}
|
||||
}
|
||||
minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext)
|
||||
req.MinIndex = min(uint32(minIndex), req.MinIndex)
|
||||
if decryptionQueue != nil {
|
||||
decryptionQueue[dbEvt.MegolmSessionID] = req
|
||||
} else {
|
||||
err = h.DB.SessionRequest.Put(ctx, req)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).
|
||||
Stringer("session_id", dbEvt.MegolmSessionID).
|
||||
Msg("Failed to save session request")
|
||||
} else {
|
||||
h.WakeupRequestQueue()
|
||||
}
|
||||
}
|
||||
}
|
||||
return dbEvt, err
|
||||
}
|
||||
|
||||
func (h *HiClient) processStateAndTimeline(
|
||||
ctx context.Context,
|
||||
room *database.Room,
|
||||
state *mautrix.SyncEventsList,
|
||||
timeline *mautrix.SyncTimeline,
|
||||
summary *mautrix.LazyLoadSummary,
|
||||
newOwnReceipts []id.EventID,
|
||||
serverNotificationCounts *mautrix.UnreadNotificationCounts,
|
||||
) error {
|
||||
updatedRoom := &database.Room{
|
||||
ID: room.ID,
|
||||
|
||||
SortingTimestamp: room.SortingTimestamp,
|
||||
NameQuality: room.NameQuality,
|
||||
UnreadHighlights: room.UnreadHighlights,
|
||||
UnreadNotifications: room.UnreadNotifications,
|
||||
UnreadMessages: room.UnreadMessages,
|
||||
}
|
||||
if serverNotificationCounts != nil {
|
||||
updatedRoom.UnreadHighlights = serverNotificationCounts.HighlightCount
|
||||
updatedRoom.UnreadNotifications = serverNotificationCounts.NotificationCount
|
||||
}
|
||||
heroesChanged := false
|
||||
if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil {
|
||||
summary = room.LazyLoadSummary
|
||||
} else if room.LazyLoadSummary == nil ||
|
||||
!slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) ||
|
||||
!intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) ||
|
||||
!intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) {
|
||||
updatedRoom.LazyLoadSummary = summary
|
||||
heroesChanged = true
|
||||
}
|
||||
decryptionQueue := make(map[id.SessionID]*database.SessionRequest)
|
||||
allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events))
|
||||
newNotifications := make([]SyncNotification, 0)
|
||||
recalculatePreviewEvent := false
|
||||
addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) {
|
||||
if rowID != 0 {
|
||||
dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID)
|
||||
} else {
|
||||
dbEvt, err = h.DB.Event.GetByID(ctx, evtID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get redaction target: %w", err)
|
||||
} else if dbEvt == nil {
|
||||
return nil, nil
|
||||
}
|
||||
allNewEvents = append(allNewEvents, dbEvt)
|
||||
return dbEvt, nil
|
||||
}
|
||||
processRedaction := func(evt *event.Event) error {
|
||||
dbEvt, err := addOldEvent(0, evt.Redacts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get redaction target: %w", err)
|
||||
}
|
||||
if dbEvt == nil {
|
||||
return nil
|
||||
}
|
||||
if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
|
||||
_, err = addOldEvent(0, dbEvt.RelatesTo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get relation target of redaction target: %w", err)
|
||||
}
|
||||
}
|
||||
if updatedRoom.PreviewEventRowID == dbEvt.RowID {
|
||||
updatedRoom.PreviewEventRowID = 0
|
||||
recalculatePreviewEvent = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) {
|
||||
evt.RoomID = room.ID
|
||||
dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, false)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
if isUnread && dbEvt.UnreadType.Is(database.UnreadTypeNotify) {
|
||||
newNotifications = append(newNotifications, SyncNotification{
|
||||
RowID: dbEvt.RowID,
|
||||
Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound),
|
||||
})
|
||||
}
|
||||
if isTimeline {
|
||||
if dbEvt.CanUseForPreview() {
|
||||
updatedRoom.PreviewEventRowID = dbEvt.RowID
|
||||
recalculatePreviewEvent = false
|
||||
}
|
||||
updatedRoom.BumpSortingTimestamp(dbEvt)
|
||||
}
|
||||
if evt.StateKey != nil {
|
||||
var membership event.Membership
|
||||
if evt.Type == event.StateMember {
|
||||
membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str)
|
||||
if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) {
|
||||
heroesChanged = true
|
||||
}
|
||||
} else if evt.Type == event.StateElementFunctionalMembers {
|
||||
heroesChanged = true
|
||||
}
|
||||
err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err)
|
||||
}
|
||||
processImportantEvent(ctx, evt, room, updatedRoom)
|
||||
}
|
||||
allNewEvents = append(allNewEvents, dbEvt)
|
||||
if evt.Type == event.EventRedaction && evt.Redacts != "" {
|
||||
err = processRedaction(evt)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("failed to process redaction: %w", err)
|
||||
}
|
||||
} else if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation {
|
||||
_, err = addOldEvent(0, dbEvt.RelatesTo)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("failed to get relation target of event: %w", err)
|
||||
}
|
||||
}
|
||||
return dbEvt.RowID, nil
|
||||
}
|
||||
changedState := make(map[event.Type]map[string]database.EventRowID)
|
||||
setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) {
|
||||
if _, ok := changedState[evtType]; !ok {
|
||||
changedState[evtType] = make(map[string]database.EventRowID)
|
||||
}
|
||||
changedState[evtType][stateKey] = rowID
|
||||
}
|
||||
for _, evt := range state.Events {
|
||||
evt.Type.Class = event.StateEventType
|
||||
rowID, err := processNewEvent(evt, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setNewState(evt.Type, *evt.StateKey, rowID)
|
||||
}
|
||||
var timelineRowTuples []database.TimelineRowTuple
|
||||
var err error
|
||||
if len(timeline.Events) > 0 {
|
||||
timelineIDs := make([]database.EventRowID, len(timeline.Events))
|
||||
readUpToIndex := -1
|
||||
for i := len(timeline.Events) - 1; i >= 0; i-- {
|
||||
if slices.Contains(newOwnReceipts, timeline.Events[i].ID) {
|
||||
readUpToIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
for i, evt := range timeline.Events {
|
||||
if evt.StateKey != nil {
|
||||
evt.Type.Class = event.StateEventType
|
||||
} else {
|
||||
evt.Type.Class = event.MessageEventType
|
||||
}
|
||||
timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if evt.StateKey != nil {
|
||||
setNewState(evt.Type, *evt.StateKey, timelineIDs[i])
|
||||
}
|
||||
}
|
||||
for _, entry := range decryptionQueue {
|
||||
err = h.DB.SessionRequest.Put(ctx, entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err)
|
||||
}
|
||||
}
|
||||
if len(decryptionQueue) > 0 {
|
||||
ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true
|
||||
}
|
||||
if timeline.Limited {
|
||||
err = h.DB.Timeline.Clear(ctx, room.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clear old timeline: %w", err)
|
||||
}
|
||||
updatedRoom.PrevBatch = timeline.PrevBatch
|
||||
h.paginationInterrupterLock.Lock()
|
||||
if interrupt, ok := h.paginationInterrupter[room.ID]; ok {
|
||||
interrupt(ErrTimelineReset)
|
||||
}
|
||||
h.paginationInterrupterLock.Unlock()
|
||||
}
|
||||
timelineRowTuples, err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to append timeline: %w", err)
|
||||
}
|
||||
} else {
|
||||
timelineRowTuples = make([]database.TimelineRowTuple, 0)
|
||||
}
|
||||
if recalculatePreviewEvent && updatedRoom.PreviewEventRowID == 0 {
|
||||
updatedRoom.PreviewEventRowID, err = h.DB.Room.RecalculatePreview(ctx, room.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to recalculate preview event: %w", err)
|
||||
}
|
||||
_, err = addOldEvent(updatedRoom.PreviewEventRowID, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get preview event: %w", err)
|
||||
}
|
||||
}
|
||||
// Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset
|
||||
if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil {
|
||||
name, dmAvatarURL, err := h.calculateRoomParticipantName(ctx, room.ID, summary)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate room name: %w", err)
|
||||
}
|
||||
updatedRoom.Name = &name
|
||||
updatedRoom.NameQuality = database.NameQualityParticipants
|
||||
if !dmAvatarURL.IsEmpty() && !room.ExplicitAvatar {
|
||||
updatedRoom.Avatar = &dmAvatarURL
|
||||
}
|
||||
}
|
||||
if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) {
|
||||
updatedRoom.PrevBatch = timeline.PrevBatch
|
||||
}
|
||||
roomChanged := updatedRoom.CheckChangesAndCopyInto(room)
|
||||
if roomChanged {
|
||||
err = h.DB.Room.Upsert(ctx, updatedRoom)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save room data: %w", err)
|
||||
}
|
||||
}
|
||||
if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 {
|
||||
ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{
|
||||
Meta: room,
|
||||
Timeline: timelineRowTuples,
|
||||
State: changedState,
|
||||
Reset: timeline.Limited,
|
||||
Events: allNewEvents,
|
||||
Notifications: newNotifications,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func joinMemberNames(names []string, totalCount int) string {
|
||||
if len(names) == 1 {
|
||||
return names[0]
|
||||
} else if len(names) < 5 || (len(names) == 5 && totalCount <= 6) {
|
||||
return strings.Join(names[:len(names)-1], ", ") + " and " + names[len(names)-1]
|
||||
} else {
|
||||
return fmt.Sprintf("%s and %d others", strings.Join(names[:4], ", "), totalCount-5)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, id.ContentURI, error) {
|
||||
var primaryAvatarURL id.ContentURI
|
||||
if summary == nil || len(summary.Heroes) == 0 {
|
||||
return "Empty room", primaryAvatarURL, nil
|
||||
}
|
||||
var functionalMembers []id.UserID
|
||||
functionalMembersEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "")
|
||||
if err != nil {
|
||||
return "", primaryAvatarURL, fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err)
|
||||
} else if functionalMembersEvt != nil {
|
||||
mautrixEvt := functionalMembersEvt.AsRawMautrix()
|
||||
_ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type)
|
||||
content, ok := mautrixEvt.Content.Parsed.(*event.ElementFunctionalMembersContent)
|
||||
if ok {
|
||||
functionalMembers = content.ServiceMembers
|
||||
}
|
||||
}
|
||||
var members, leftMembers []string
|
||||
var memberCount int
|
||||
if summary.JoinedMemberCount != nil && *summary.JoinedMemberCount > 0 {
|
||||
memberCount = *summary.JoinedMemberCount
|
||||
} else if summary.InvitedMemberCount != nil {
|
||||
memberCount = *summary.InvitedMemberCount
|
||||
}
|
||||
for _, hero := range summary.Heroes {
|
||||
if slices.Contains(functionalMembers, hero) {
|
||||
memberCount--
|
||||
continue
|
||||
} else if len(members) >= 5 {
|
||||
break
|
||||
}
|
||||
heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String())
|
||||
if err != nil {
|
||||
return "", primaryAvatarURL, fmt.Errorf("failed to get %s's member event: %w", hero, err)
|
||||
} else if heroEvt == nil {
|
||||
leftMembers = append(leftMembers, hero.String())
|
||||
continue
|
||||
}
|
||||
membership := gjson.GetBytes(heroEvt.Content, "membership").Str
|
||||
name := gjson.GetBytes(heroEvt.Content, "displayname").Str
|
||||
if name == "" {
|
||||
name = hero.String()
|
||||
}
|
||||
avatarURL := gjson.GetBytes(heroEvt.Content, "avatar_url").Str
|
||||
if avatarURL != "" {
|
||||
primaryAvatarURL = id.ContentURIString(avatarURL).ParseOrIgnore()
|
||||
}
|
||||
if membership == "join" || membership == "invite" {
|
||||
members = append(members, name)
|
||||
} else {
|
||||
leftMembers = append(leftMembers, name)
|
||||
}
|
||||
}
|
||||
if len(members)+len(leftMembers) > 1 || !primaryAvatarURL.IsValid() {
|
||||
primaryAvatarURL = id.ContentURI{}
|
||||
}
|
||||
if len(members) > 0 {
|
||||
return joinMemberNames(members, memberCount), primaryAvatarURL, nil
|
||||
} else if len(leftMembers) > 0 {
|
||||
return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), primaryAvatarURL, nil
|
||||
} else {
|
||||
return "Empty room", primaryAvatarURL, nil
|
||||
}
|
||||
}
|
||||
|
||||
func intPtrEqual(a, b *int) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == b
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData, updatedRoom *database.Room) (roomDataChanged bool) {
|
||||
if evt.StateKey == nil {
|
||||
return
|
||||
}
|
||||
switch evt.Type {
|
||||
case event.StateCreate, event.StateRoomName, event.StateCanonicalAlias, event.StateRoomAvatar, event.StateTopic, event.StateEncryption:
|
||||
if *evt.StateKey != "" {
|
||||
return
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) {
|
||||
zerolog.Ctx(ctx).Warn().Err(err).
|
||||
Stringer("event_type", &evt.Type).
|
||||
Stringer("event_id", evt.ID).
|
||||
Msg("Failed to parse state event, skipping")
|
||||
return
|
||||
}
|
||||
switch evt.Type {
|
||||
case event.StateCreate:
|
||||
updatedRoom.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent)
|
||||
case event.StateEncryption:
|
||||
newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent)
|
||||
if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm {
|
||||
updatedRoom.EncryptionEvent = newEncryption
|
||||
}
|
||||
case event.StateRoomName:
|
||||
content, ok := evt.Content.Parsed.(*event.RoomNameEventContent)
|
||||
if ok {
|
||||
updatedRoom.Name = &content.Name
|
||||
updatedRoom.NameQuality = database.NameQualityExplicit
|
||||
if content.Name == "" {
|
||||
if updatedRoom.CanonicalAlias != nil && *updatedRoom.CanonicalAlias != "" {
|
||||
updatedRoom.Name = (*string)(updatedRoom.CanonicalAlias)
|
||||
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
|
||||
} else if existingRoomData.CanonicalAlias != nil && *existingRoomData.CanonicalAlias != "" {
|
||||
updatedRoom.Name = (*string)(existingRoomData.CanonicalAlias)
|
||||
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
|
||||
} else {
|
||||
updatedRoom.NameQuality = database.NameQualityNil
|
||||
}
|
||||
}
|
||||
}
|
||||
case event.StateCanonicalAlias:
|
||||
content, ok := evt.Content.Parsed.(*event.CanonicalAliasEventContent)
|
||||
if ok {
|
||||
updatedRoom.CanonicalAlias = &content.Alias
|
||||
if updatedRoom.NameQuality <= database.NameQualityCanonicalAlias {
|
||||
updatedRoom.Name = (*string)(&content.Alias)
|
||||
updatedRoom.NameQuality = database.NameQualityCanonicalAlias
|
||||
if content.Alias == "" {
|
||||
updatedRoom.NameQuality = database.NameQualityNil
|
||||
}
|
||||
}
|
||||
}
|
||||
case event.StateRoomAvatar:
|
||||
content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent)
|
||||
if ok {
|
||||
url, _ := content.URL.Parse()
|
||||
updatedRoom.Avatar = &url
|
||||
updatedRoom.ExplicitAvatar = true
|
||||
}
|
||||
case event.StateTopic:
|
||||
content, ok := evt.Content.Parsed.(*event.TopicEventContent)
|
||||
if ok {
|
||||
updatedRoom.Topic = &content.Topic
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
96
pkg/hicli/syncwrap.go
Normal file
96
pkg/hicli/syncwrap.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type hiSyncer HiClient
|
||||
|
||||
var _ mautrix.Syncer = (*hiSyncer)(nil)
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
syncContextKey contextKey = iota
|
||||
)
|
||||
|
||||
func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
|
||||
c := (*HiClient)(h)
|
||||
ctx = context.WithValue(ctx, syncContextKey, &syncContext{evt: &SyncComplete{Rooms: make(map[id.RoomID]*SyncRoom, len(resp.Rooms.Join))}})
|
||||
err := c.preProcessSyncResponse(ctx, resp, since)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
return c.processSyncResponse(ctx, resp, since)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.postProcessSyncResponse(ctx, resp, since)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *hiSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
|
||||
(*HiClient)(h).Log.Err(err).Msg("Sync failed, retrying in 1 second")
|
||||
return 1 * time.Second, nil
|
||||
}
|
||||
|
||||
func (h *hiSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
|
||||
if !h.Verified {
|
||||
return &mautrix.Filter{
|
||||
Presence: mautrix.FilterPart{
|
||||
NotRooms: []id.RoomID{"*"},
|
||||
},
|
||||
Room: mautrix.RoomFilter{
|
||||
NotRooms: []id.RoomID{"*"},
|
||||
},
|
||||
}
|
||||
}
|
||||
return &mautrix.Filter{
|
||||
Presence: mautrix.FilterPart{
|
||||
NotRooms: []id.RoomID{"*"},
|
||||
},
|
||||
Room: mautrix.RoomFilter{
|
||||
State: mautrix.FilterPart{
|
||||
LazyLoadMembers: true,
|
||||
},
|
||||
Timeline: mautrix.FilterPart{
|
||||
Limit: 100,
|
||||
LazyLoadMembers: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type hiStore HiClient
|
||||
|
||||
var _ mautrix.SyncStore = (*hiStore)(nil)
|
||||
|
||||
// Filter ID save and load are intentionally no-ops: we want to recreate filters when restarting syncing
|
||||
|
||||
func (h *hiStore) SaveFilterID(_ context.Context, _ id.UserID, _ string) error { return nil }
|
||||
func (h *hiStore) LoadFilterID(_ context.Context, _ id.UserID) (string, error) { return "", nil }
|
||||
|
||||
func (h *hiStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
|
||||
// This is intentionally a no-op: we don't want to save the next batch before processing the sync
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *hiStore) LoadNextBatch(_ context.Context, userID id.UserID) (string, error) {
|
||||
if h.Account.UserID != userID {
|
||||
return "", fmt.Errorf("mismatching user ID")
|
||||
}
|
||||
return h.Account.NextBatch, nil
|
||||
}
|
161
pkg/hicli/verify.go
Normal file
161
pkg/hicli/verify.go
Normal file
|
@ -0,0 +1,161 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package hicli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/crypto/backup"
|
||||
"maunium.net/go/mautrix/crypto/ssss"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func (h *HiClient) checkIsCurrentDeviceVerified(ctx context.Context) (bool, error) {
|
||||
keys := h.Crypto.GetOwnCrossSigningPublicKeys(ctx)
|
||||
if keys == nil {
|
||||
return false, fmt.Errorf("own cross-signing keys not found")
|
||||
}
|
||||
isVerified, err := h.Crypto.CryptoStore.IsKeySignedBy(ctx, h.Account.UserID, h.Crypto.GetAccount().SigningKey(), h.Account.UserID, keys.SelfSigningKey)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err)
|
||||
}
|
||||
return isVerified, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) fetchKeyBackupKey(ctx context.Context, ssssKey *ssss.Key) error {
|
||||
latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get key backup latest version: %w", err)
|
||||
}
|
||||
h.KeyBackupVersion = latestVersion.Version
|
||||
data, err := h.Crypto.SSSS.GetDecryptedAccountData(ctx, event.AccountDataMegolmBackupKey, ssssKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get megolm backup key from SSSS: %w", err)
|
||||
}
|
||||
key, err := backup.MegolmBackupKeyFromBytes(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse megolm backup key: %w", err)
|
||||
}
|
||||
err = h.CryptoStore.PutSecret(ctx, id.SecretMegolmBackupV1, base64.StdEncoding.EncodeToString(key.Bytes()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store megolm backup key: %w", err)
|
||||
}
|
||||
h.KeyBackupKey = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) getAndDecodeSecret(ctx context.Context, secret id.Secret) ([]byte, error) {
|
||||
secretData, err := h.CryptoStore.GetSecret(ctx, secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get secret %s: %w", secret, err)
|
||||
}
|
||||
data, err := base64.StdEncoding.DecodeString(secretData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode secret %s: %w", secret, err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (h *HiClient) loadPrivateKeys(ctx context.Context) error {
|
||||
zerolog.Ctx(ctx).Debug().Msg("Loading cross-signing private keys")
|
||||
masterKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSMaster)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
selfSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSSelfSigning)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get self-signing key: %w", err)
|
||||
}
|
||||
userSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSUserSigning)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user signing key: %w", err)
|
||||
}
|
||||
err = h.Crypto.ImportCrossSigningKeys(crypto.CrossSigningSeeds{
|
||||
MasterKey: masterKeySeed,
|
||||
SelfSigningKey: selfSigningKeySeed,
|
||||
UserSigningKey: userSigningKeySeed,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to import cross-signing private keys: %w", err)
|
||||
}
|
||||
zerolog.Ctx(ctx).Debug().Msg("Loading key backup key")
|
||||
keyBackupKey, err := h.getAndDecodeSecret(ctx, id.SecretMegolmBackupV1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get megolm backup key: %w", err)
|
||||
}
|
||||
h.KeyBackupKey, err = backup.MegolmBackupKeyFromBytes(keyBackupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse megolm backup key: %w", err)
|
||||
}
|
||||
zerolog.Ctx(ctx).Debug().Msg("Fetching key backup version")
|
||||
latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get key backup latest version: %w", err)
|
||||
}
|
||||
h.KeyBackupVersion = latestVersion.Version
|
||||
zerolog.Ctx(ctx).Debug().Msg("Secrets loaded")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error {
|
||||
keys := h.Crypto.CrossSigningKeys
|
||||
err := h.CryptoStore.PutSecret(ctx, id.SecretXSMaster, base64.StdEncoding.EncodeToString(keys.MasterKey.Seed()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.CryptoStore.PutSecret(ctx, id.SecretXSSelfSigning, base64.StdEncoding.EncodeToString(keys.SelfSigningKey.Seed()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.CryptoStore.PutSecret(ctx, id.SecretXSUserSigning, base64.StdEncoding.EncodeToString(keys.UserSigningKey.Seed()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *HiClient) VerifyWithRecoveryKey(ctx context.Context, code string) error {
|
||||
defer h.dispatchCurrentState()
|
||||
keyID, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get default SSSS key data: %w", err)
|
||||
}
|
||||
key, err := keyData.VerifyRecoveryKey(keyID, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.Crypto.FetchCrossSigningKeysFromSSSS(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err)
|
||||
}
|
||||
err = h.Crypto.SignOwnDevice(ctx, h.Crypto.OwnIdentity())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign own device: %w", err)
|
||||
}
|
||||
err = h.Crypto.SignOwnMasterKey(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign own master key: %w", err)
|
||||
}
|
||||
err = h.storeCrossSigningPrivateKeys(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store cross-signing private keys: %w", err)
|
||||
}
|
||||
err = h.fetchKeyBackupKey(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch key backup key: %w", err)
|
||||
}
|
||||
h.Verified = true
|
||||
if !h.IsSyncing() {
|
||||
go h.Sync()
|
||||
}
|
||||
return nil
|
||||
}
|
120
pkg/rainbow/goldmark.go
Normal file
120
pkg/rainbow/goldmark.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package rainbow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unicode"
|
||||
|
||||
"github.com/rivo/uniseg"
|
||||
"github.com/yuin/goldmark"
|
||||
"github.com/yuin/goldmark/ast"
|
||||
"github.com/yuin/goldmark/renderer"
|
||||
"github.com/yuin/goldmark/renderer/html"
|
||||
"github.com/yuin/goldmark/util"
|
||||
"go.mau.fi/util/random"
|
||||
)
|
||||
|
||||
// Extension is a goldmark extension that adds rainbow text coloring to the HTML renderer.
|
||||
var Extension = &extRainbow{}
|
||||
|
||||
type extRainbow struct{}
|
||||
type rainbowRenderer struct {
|
||||
HardWraps bool
|
||||
ColorID string
|
||||
}
|
||||
|
||||
var defaultRB = &rainbowRenderer{HardWraps: true, ColorID: random.String(16)}
|
||||
|
||||
func (er *extRainbow) Extend(m goldmark.Markdown) {
|
||||
m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(defaultRB, 0)))
|
||||
}
|
||||
|
||||
func (rb *rainbowRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
|
||||
reg.Register(ast.KindText, rb.renderText)
|
||||
reg.Register(ast.KindString, rb.renderString)
|
||||
}
|
||||
|
||||
type rainbowBufWriter struct {
|
||||
util.BufWriter
|
||||
ColorID string
|
||||
}
|
||||
|
||||
func (rbw rainbowBufWriter) WriteString(s string) (int, error) {
|
||||
i := 0
|
||||
graphemes := uniseg.NewGraphemes(s)
|
||||
for graphemes.Next() {
|
||||
runes := graphemes.Runes()
|
||||
if len(runes) == 1 && unicode.IsSpace(runes[0]) {
|
||||
i2, err := rbw.BufWriter.WriteRune(runes[0])
|
||||
i += i2
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
i2, err := fmt.Fprintf(rbw.BufWriter, "<font color=\"%s\">%s</font>", rbw.ColorID, graphemes.Str())
|
||||
i += i2
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (rbw rainbowBufWriter) Write(data []byte) (int, error) {
|
||||
return rbw.WriteString(string(data))
|
||||
}
|
||||
|
||||
func (rbw rainbowBufWriter) WriteByte(c byte) error {
|
||||
_, err := rbw.WriteRune(rune(c))
|
||||
return err
|
||||
}
|
||||
|
||||
func (rbw rainbowBufWriter) WriteRune(r rune) (int, error) {
|
||||
if unicode.IsSpace(r) {
|
||||
return rbw.BufWriter.WriteRune(r)
|
||||
} else {
|
||||
return fmt.Fprintf(rbw.BufWriter, "<font color=\"%s\">%c</font>", rbw.ColorID, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (rb *rainbowRenderer) renderText(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) {
|
||||
if !entering {
|
||||
return ast.WalkContinue, nil
|
||||
}
|
||||
n := node.(*ast.Text)
|
||||
segment := n.Segment
|
||||
if n.IsRaw() {
|
||||
html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, segment.Value(source))
|
||||
} else {
|
||||
html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, segment.Value(source))
|
||||
if n.HardLineBreak() || (n.SoftLineBreak() && rb.HardWraps) {
|
||||
_, _ = w.WriteString("<br>\n")
|
||||
} else if n.SoftLineBreak() {
|
||||
_ = w.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
return ast.WalkContinue, nil
|
||||
}
|
||||
|
||||
func (rb *rainbowRenderer) renderString(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) {
|
||||
if !entering {
|
||||
return ast.WalkContinue, nil
|
||||
}
|
||||
n := node.(*ast.String)
|
||||
if n.IsCode() {
|
||||
_, _ = w.Write(n.Value)
|
||||
} else {
|
||||
if n.IsRaw() {
|
||||
html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, n.Value)
|
||||
} else {
|
||||
html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, n.Value)
|
||||
}
|
||||
}
|
||||
return ast.WalkContinue, nil
|
||||
}
|
56
pkg/rainbow/gradient.go
Normal file
56
pkg/rainbow/gradient.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package rainbow
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/lucasb-eyer/go-colorful"
|
||||
)
|
||||
|
||||
// GradientTable from https://github.com/lucasb-eyer/go-colorful/blob/master/doc/gradientgen/gradientgen.go
|
||||
type GradientTable []struct {
|
||||
Col colorful.Color
|
||||
Pos float64
|
||||
}
|
||||
|
||||
func (gt GradientTable) GetInterpolatedColorFor(t float64) colorful.Color {
|
||||
for i := 0; i < len(gt)-1; i++ {
|
||||
c1 := gt[i]
|
||||
c2 := gt[i+1]
|
||||
if c1.Pos <= t && t <= c2.Pos {
|
||||
t := (t - c1.Pos) / (c2.Pos - c1.Pos)
|
||||
return c1.Col.BlendHcl(c2.Col, t).Clamped()
|
||||
}
|
||||
}
|
||||
return gt[len(gt)-1].Col
|
||||
}
|
||||
|
||||
var Gradient = GradientTable{
|
||||
{colorful.LinearRgb(1, 0, 0), 0 / 11.0},
|
||||
{colorful.LinearRgb(1, 0.5, 0), 1 / 11.0},
|
||||
{colorful.LinearRgb(1, 1, 0), 2 / 11.0},
|
||||
{colorful.LinearRgb(0.5, 1, 0), 3 / 11.0},
|
||||
{colorful.LinearRgb(0, 1, 0), 4 / 11.0},
|
||||
{colorful.LinearRgb(0, 1, 0.5), 5 / 11.0},
|
||||
{colorful.LinearRgb(0, 1, 1), 6 / 11.0},
|
||||
{colorful.LinearRgb(0, 0.5, 1), 7 / 11.0},
|
||||
{colorful.LinearRgb(0, 0, 1), 8 / 11.0},
|
||||
{colorful.LinearRgb(0.5, 0, 1), 9 / 11.0},
|
||||
{colorful.LinearRgb(1, 0, 1), 10 / 11.0},
|
||||
{colorful.LinearRgb(1, 0, 0.5), 11 / 11.0},
|
||||
}
|
||||
|
||||
func ApplyColor(htmlBody string) string {
|
||||
count := strings.Count(htmlBody, defaultRB.ColorID)
|
||||
i := -1
|
||||
return regexp.MustCompile(defaultRB.ColorID).ReplaceAllStringFunc(htmlBody, func(match string) string {
|
||||
i++
|
||||
return Gradient.GetInterpolatedColorFor(float64(i) / float64(count)).Hex()
|
||||
})
|
||||
}
|
47
server.go
47
server.go
|
@ -82,11 +82,12 @@ var (
|
|||
)
|
||||
|
||||
type tokenData struct {
|
||||
Username string `json:"username"`
|
||||
Expiry jsontime.Unix `json:"expiry"`
|
||||
Username string `json:"username"`
|
||||
Expiry jsontime.Unix `json:"expiry"`
|
||||
ImageOnly bool `json:"image_only,omitempty"`
|
||||
}
|
||||
|
||||
func (gmx *Gomuks) validateAuth(token string) bool {
|
||||
func (gmx *Gomuks) validateAuth(token string, imageOnly bool) bool {
|
||||
if len(token) > 500 {
|
||||
return false
|
||||
}
|
||||
|
@ -110,19 +111,31 @@ func (gmx *Gomuks) validateAuth(token string) bool {
|
|||
|
||||
var td tokenData
|
||||
err = json.Unmarshal(rawJSON, &td)
|
||||
return err == nil && td.Username == gmx.Config.Web.Username && td.Expiry.After(time.Now())
|
||||
return err == nil && td.Username == gmx.Config.Web.Username && td.Expiry.After(time.Now()) && td.ImageOnly == imageOnly
|
||||
}
|
||||
|
||||
func (gmx *Gomuks) generateToken() (string, time.Time) {
|
||||
expiry := time.Now().Add(7 * 24 * time.Hour)
|
||||
data := exerrors.Must(json.Marshal(tokenData{
|
||||
return gmx.signToken(tokenData{
|
||||
Username: gmx.Config.Web.Username,
|
||||
Expiry: jsontime.U(expiry),
|
||||
}))
|
||||
}), expiry
|
||||
}
|
||||
|
||||
func (gmx *Gomuks) generateImageToken() string {
|
||||
return gmx.signToken(tokenData{
|
||||
Username: gmx.Config.Web.Username,
|
||||
Expiry: jsontime.U(time.Now().Add(1 * time.Hour)),
|
||||
ImageOnly: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (gmx *Gomuks) signToken(td tokenData) string {
|
||||
data := exerrors.Must(json.Marshal(td))
|
||||
hasher := hmac.New(sha256.New, []byte(gmx.Config.Web.TokenKey))
|
||||
hasher.Write(data)
|
||||
checksum := hasher.Sum(nil)
|
||||
return base64.RawURLEncoding.EncodeToString(data) + "." + base64.RawURLEncoding.EncodeToString(checksum), expiry
|
||||
return base64.RawURLEncoding.EncodeToString(data) + "." + base64.RawURLEncoding.EncodeToString(checksum)
|
||||
}
|
||||
|
||||
func (gmx *Gomuks) writeTokenCookie(w http.ResponseWriter) {
|
||||
|
@ -137,7 +150,7 @@ func (gmx *Gomuks) writeTokenCookie(w http.ResponseWriter) {
|
|||
|
||||
func (gmx *Gomuks) Authenticate(w http.ResponseWriter, r *http.Request) {
|
||||
authCookie, err := r.Cookie("gomuks_auth")
|
||||
if err == nil && gmx.validateAuth(authCookie.Value) {
|
||||
if err == nil && gmx.validateAuth(authCookie.Value, false) {
|
||||
gmx.writeTokenCookie(w)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else if username, password, ok := r.BasicAuth(); !ok {
|
||||
|
@ -164,9 +177,23 @@ func isUserFetch(header http.Header) bool {
|
|||
header.Get("Sec-Fetch-User") == "?1"
|
||||
}
|
||||
|
||||
func isImageFetch(header http.Header) bool {
|
||||
return header.Get("Sec-Fetch-Site") == "cross-site" &&
|
||||
header.Get("Sec-Fetch-Mode") == "no-cors" &&
|
||||
header.Get("Sec-Fetch-Dest") == "image"
|
||||
}
|
||||
|
||||
func (gmx *Gomuks) AuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Sec-Fetch-Site") != "" && r.Header.Get("Sec-Fetch-Site") != "same-origin" && !isUserFetch(r.Header) {
|
||||
if strings.HasPrefix(r.URL.Path, "/media") &&
|
||||
isImageFetch(r.Header) &&
|
||||
gmx.validateAuth(r.URL.Query().Get("image_auth"), true) &&
|
||||
r.URL.Query().Get("encrypted") == "false" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
} else if r.Header.Get("Sec-Fetch-Site") != "" &&
|
||||
r.Header.Get("Sec-Fetch-Site") != "same-origin" &&
|
||||
!isUserFetch(r.Header) {
|
||||
hlog.FromRequest(r).Debug().
|
||||
Str("site", r.Header.Get("Sec-Fetch-Site")).
|
||||
Str("dest", r.Header.Get("Sec-Fetch-Dest")).
|
||||
|
@ -181,7 +208,7 @@ func (gmx *Gomuks) AuthMiddleware(next http.Handler) http.Handler {
|
|||
if err != nil {
|
||||
ErrMissingCookie.Write(w)
|
||||
return
|
||||
} else if !gmx.validateAuth(authCookie.Value) {
|
||||
} else if !gmx.validateAuth(authCookie.Value, false) {
|
||||
ErrInvalidCookie.Write(w)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -29,6 +29,8 @@ function App() {
|
|||
const clientState = useEventAsState(client.state)
|
||||
;((window as unknown) as { client: Client }).client = client
|
||||
useEffect(() => {
|
||||
Notification.requestPermission()
|
||||
.then(permission => console.log("Notification permission:", permission))
|
||||
client.rpc.start()
|
||||
return () => client.rpc.stop()
|
||||
}, [client])
|
||||
|
|
|
@ -39,6 +39,8 @@ export default class Client {
|
|||
this.store.applyDecrypted(ev.data)
|
||||
} else if (ev.command === "send_complete") {
|
||||
this.store.applySendComplete(ev.data)
|
||||
} else if (ev.command === "image_auth_token") {
|
||||
this.store.imageAuthToken = ev.data
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -14,9 +14,11 @@
|
|||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
import unhomoglyph from "unhomoglyph"
|
||||
import { getMediaURL } from "@/api/media.ts"
|
||||
import { NonNullCachedEventDispatcher } from "@/util/eventdispatcher.ts"
|
||||
import { focused } from "@/util/focus.ts"
|
||||
import type {
|
||||
ContentURI,
|
||||
ContentURI, EventRowID,
|
||||
EventsDecryptedData,
|
||||
MemDBEvent,
|
||||
RoomID,
|
||||
|
@ -34,6 +36,8 @@ export interface RoomListEntry {
|
|||
name: string
|
||||
search_name: string
|
||||
avatar?: ContentURI
|
||||
unread: number
|
||||
highlighted: boolean
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-misleading-character-class
|
||||
|
@ -49,9 +53,13 @@ export function toSearchableString(str: string): string {
|
|||
export class StateStore {
|
||||
readonly rooms: Map<RoomID, RoomStateStore> = new Map()
|
||||
readonly roomList = new NonNullCachedEventDispatcher<RoomListEntry[]>([])
|
||||
switchRoom?: (roomID: RoomID) => void
|
||||
imageAuthToken?: string
|
||||
|
||||
#roomListEntryChanged(entry: SyncRoom, oldEntry: RoomStateStore): boolean {
|
||||
return entry.meta.sorting_timestamp !== oldEntry.meta.current.sorting_timestamp ||
|
||||
entry.meta.unread_notifications !== oldEntry.meta.current.unread_notifications ||
|
||||
entry.meta.unread_highlights !== oldEntry.meta.current.unread_highlights ||
|
||||
entry.meta.preview_event_rowid !== oldEntry.meta.current.preview_event_rowid ||
|
||||
entry.events.findIndex(evt => evt.rowid === entry.meta.preview_event_rowid) !== -1
|
||||
}
|
||||
|
@ -71,6 +79,8 @@ export class StateStore {
|
|||
name,
|
||||
search_name: toSearchableString(name),
|
||||
avatar: entry.meta.avatar,
|
||||
unread: entry.meta.unread_notifications,
|
||||
highlighted: entry.meta.unread_highlights > 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,6 +100,12 @@ export class StateStore {
|
|||
if (roomListEntryChanged) {
|
||||
changedRoomListEntries.set(roomID, this.#makeRoomListEntry(data, room))
|
||||
}
|
||||
|
||||
if (Notification.permission === "granted" && !focused.current) {
|
||||
for (const notification of data.notifications) {
|
||||
this.showNotification(room, notification.event_rowid, notification.sound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let updatedRoomList: RoomListEntry[] | undefined
|
||||
|
@ -117,6 +133,40 @@ export class StateStore {
|
|||
}
|
||||
}
|
||||
|
||||
showNotification(room: RoomStateStore, rowid: EventRowID, sound: boolean) {
|
||||
const evt = room.eventsByRowID.get(rowid)
|
||||
if (!evt || typeof evt.content.body !== "string") {
|
||||
return
|
||||
}
|
||||
let body = evt.content.body
|
||||
if (body.length > 200) {
|
||||
body = body.slice(0, 150) + " […]"
|
||||
}
|
||||
const memberEvt = room.getStateEvent("m.room.member", evt.sender)
|
||||
const icon = `${getMediaURL(memberEvt?.content.avatar_url)}&image_auth=${this.imageAuthToken}`
|
||||
const roomName = room.meta.current.name ?? "Unnamed room"
|
||||
const senderName = memberEvt?.content.displayname ?? evt.sender
|
||||
const title = senderName === roomName ? senderName : `${senderName} (${roomName})`
|
||||
const notif = new Notification(title, {
|
||||
body,
|
||||
icon,
|
||||
badge: "/gomuks.png",
|
||||
// timestamp: evt.timestamp,
|
||||
// image: ...,
|
||||
tag: rowid.toString(),
|
||||
})
|
||||
notif.onclick = () => this.onClickNotification(room.roomID)
|
||||
if (sound) {
|
||||
// TODO play sound
|
||||
}
|
||||
}
|
||||
|
||||
onClickNotification(roomID: RoomID) {
|
||||
if (this.switchRoom) {
|
||||
this.switchRoom(roomID)
|
||||
}
|
||||
}
|
||||
|
||||
applySendComplete(data: SendCompleteData) {
|
||||
const room = this.rooms.get(data.event.room_id)
|
||||
if (!room) {
|
||||
|
|
|
@ -147,6 +147,7 @@ export class RoomStateStore {
|
|||
if (memEvt.last_edit) {
|
||||
memEvt.orig_content = memEvt.content
|
||||
memEvt.content = memEvt.last_edit.content["m.new_content"]
|
||||
memEvt.local_content = memEvt.last_edit.local_content
|
||||
}
|
||||
} else if (memEvt.relation_type === "m.replace" && memEvt.relates_to) {
|
||||
const editTarget = this.eventsByID.get(memEvt.relates_to)
|
||||
|
@ -154,6 +155,7 @@ export class RoomStateStore {
|
|||
editTarget.last_edit = memEvt
|
||||
editTarget.orig_content = editTarget.content
|
||||
editTarget.content = memEvt.content["m.new_content"]
|
||||
editTarget.local_content = memEvt.local_content
|
||||
}
|
||||
}
|
||||
this.eventsByRowID.set(memEvt.rowid, memEvt)
|
||||
|
|
|
@ -60,12 +60,22 @@ export interface EventsDecryptedEvent extends RPCCommand<EventsDecryptedData> {
|
|||
command: "events_decrypted"
|
||||
}
|
||||
|
||||
export interface ImageAuthTokenEvent extends RPCCommand<string> {
|
||||
command: "image_auth_token"
|
||||
}
|
||||
|
||||
export interface SyncRoom {
|
||||
meta: DBRoom
|
||||
timeline: TimelineRowTuple[]
|
||||
events: RawDBEvent[]
|
||||
state: Record<EventType, Record<string, EventRowID>>
|
||||
reset: boolean
|
||||
notifications: SyncNotification[]
|
||||
}
|
||||
|
||||
export interface SyncNotification {
|
||||
event_rowid: EventRowID
|
||||
sound: boolean
|
||||
}
|
||||
|
||||
export interface SyncCompleteData {
|
||||
|
@ -97,4 +107,5 @@ export type RPCEvent =
|
|||
TypingEvent |
|
||||
SendCompleteEvent |
|
||||
EventsDecryptedEvent |
|
||||
SyncCompleteEvent
|
||||
SyncCompleteEvent |
|
||||
ImageAuthTokenEvent
|
||||
|
|
|
@ -59,6 +59,9 @@ export interface DBRoom {
|
|||
|
||||
preview_event_rowid: EventRowID
|
||||
sorting_timestamp: number
|
||||
unread_highlights: number
|
||||
unread_notifications: number
|
||||
unread_messages: number
|
||||
|
||||
prev_batch: string
|
||||
}
|
||||
|
@ -66,6 +69,18 @@ export interface DBRoom {
|
|||
//eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export type UnknownEventContent = Record<string, any>
|
||||
|
||||
export enum UnreadType {
|
||||
None = 0b0000,
|
||||
Normal = 0b0001,
|
||||
Notify = 0b0010,
|
||||
Highlight = 0b0100,
|
||||
Sound = 0b1000,
|
||||
}
|
||||
|
||||
export interface LocalContent {
|
||||
sanitized_html?: TrustedHTML
|
||||
}
|
||||
|
||||
export interface BaseDBEvent {
|
||||
rowid: EventRowID
|
||||
timeline_rowid: TimelineRowID
|
||||
|
@ -79,6 +94,7 @@ export interface BaseDBEvent {
|
|||
|
||||
content: UnknownEventContent
|
||||
unsigned: EventUnsigned
|
||||
local_content?: LocalContent
|
||||
|
||||
transaction_id?: string
|
||||
|
||||
|
@ -91,6 +107,7 @@ export interface BaseDBEvent {
|
|||
|
||||
reactions?: Record<string, number>
|
||||
last_edit_rowid?: EventRowID
|
||||
unread_type: UnreadType
|
||||
}
|
||||
|
||||
export interface RawDBEvent extends BaseDBEvent {
|
||||
|
|
|
@ -31,6 +31,7 @@ const MainScreen = () => {
|
|||
.catch(err => console.error("Failed to load room state", err))
|
||||
}
|
||||
}, [client])
|
||||
client.store.switchRoom = setActiveRoom
|
||||
const clearActiveRoom = useCallback(() => setActiveRoomID(null), [])
|
||||
return <main className={`matrix-main ${activeRoom ? "room-selected" : ""}`}>
|
||||
<RoomList setActiveRoom={setActiveRoom} activeRoomID={activeRoomID} />
|
||||
|
|
|
@ -60,6 +60,11 @@ const Entry = ({ room, setActiveRoom, isActive, hidden }: RoomListEntryProps) =>
|
|||
<div className="room-name">{room.name}</div>
|
||||
{previewText && <div className="message-preview" title={previewText}>{croppedPreviewText}</div>}
|
||||
</div>
|
||||
{room.unread ? <div className="room-entry-unreads">
|
||||
<div className={`unread-count ${room.highlighted ? "highlighted" : ""}`}>
|
||||
{room.unread}
|
||||
</div>
|
||||
</div> : null}
|
||||
</div>
|
||||
}
|
||||
|
||||
|
|
|
@ -78,6 +78,26 @@ div.room-entry {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
> div.room-entry-unreads {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
width: 3rem;
|
||||
|
||||
> div.unread-count {
|
||||
width: 1.5rem;
|
||||
height: 1.5rem;
|
||||
border-radius: 50%;
|
||||
background-color: green;
|
||||
text-align: center;
|
||||
color: white;
|
||||
font-weight: bold;
|
||||
|
||||
&.highlighted {
|
||||
background-color: darkred;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
img.avatar {
|
||||
|
|
|
@ -166,7 +166,20 @@ div.html-body {
|
|||
vertical-align: middle;
|
||||
}
|
||||
|
||||
span[data-mx-spoiler] {
|
||||
img.hicli-custom-emoji {
|
||||
vertical-align: middle;
|
||||
height: 24px;
|
||||
width: auto;
|
||||
max-width: 72px;
|
||||
}
|
||||
|
||||
img.hicli-sizeless-inline-img {
|
||||
height: 24px;
|
||||
width: auto;
|
||||
max-width: 72px;
|
||||
}
|
||||
|
||||
span[data-mx-spoiler], span.hicli-spoiler {
|
||||
filter: blur(4px);
|
||||
transition: filter .5s;
|
||||
cursor: pointer;
|
||||
|
|
|
@ -24,7 +24,11 @@ import { ReplyIDBody } from "./ReplyBody.tsx"
|
|||
import EncryptedBody from "./content/EncryptedBody.tsx"
|
||||
import HiddenEvent from "./content/HiddenEvent.tsx"
|
||||
import MemberBody from "./content/MemberBody.tsx"
|
||||
import { MediaMessageBody, TextMessageBody, UnknownMessageBody } from "./content/MessageBody.tsx"
|
||||
import {
|
||||
MediaMessageBody,
|
||||
TextMessageBody,
|
||||
UnknownMessageBody,
|
||||
} from "./content/MessageBody.tsx"
|
||||
import RedactedBody from "./content/RedactedBody.tsx"
|
||||
import { EventContentProps } from "./content/props.ts"
|
||||
import ErrorIcon from "../../icons/error.svg?react"
|
||||
|
|
|
@ -13,11 +13,9 @@
|
|||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
import { CSSProperties, use, useMemo } from "react"
|
||||
import sanitizeHtml from "sanitize-html"
|
||||
import { CSSProperties, use } from "react"
|
||||
import { getEncryptedMediaURL, getMediaURL } from "@/api/media.ts"
|
||||
import type { EventType, MediaMessageEventContent, MessageEventContent } from "@/api/types"
|
||||
import { sanitizeHtmlParams } from "@/util/html.ts"
|
||||
import { calculateMediaSize } from "@/util/mediasize.ts"
|
||||
import { LightboxContext } from "../../Lightbox.tsx"
|
||||
import { EventContentProps } from "./props.ts"
|
||||
|
@ -32,14 +30,10 @@ const onClickHTML = (evt: React.MouseEvent<HTMLDivElement>) => {
|
|||
|
||||
export const TextMessageBody = ({ event }: EventContentProps) => {
|
||||
const content = event.content as MessageEventContent
|
||||
const __html = useMemo(() => {
|
||||
if (content.format === "org.matrix.custom.html") {
|
||||
return sanitizeHtml(content.formatted_body!, sanitizeHtmlParams)
|
||||
}
|
||||
return undefined
|
||||
}, [content.format, content.formatted_body])
|
||||
if (__html) {
|
||||
return <div onClick={onClickHTML} className="message-text html-body" dangerouslySetInnerHTML={{ __html }}/>
|
||||
if (event.local_content?.sanitized_html) {
|
||||
return <div onClick={onClickHTML} className="message-text html-body" dangerouslySetInnerHTML={{
|
||||
__html: event.local_content!.sanitized_html!,
|
||||
}}/>
|
||||
}
|
||||
return <div className="message-text plaintext-body">{content.body}</div>
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
import { NonNullCachedEventDispatcher, useNonNullEventAsState } from "@/util/eventdispatcher.ts"
|
||||
|
||||
const focused = new NonNullCachedEventDispatcher(document.hasFocus())
|
||||
export const focused = new NonNullCachedEventDispatcher(document.hasFocus())
|
||||
window.addEventListener("focus", () => focused.emit(true))
|
||||
window.addEventListener("blur", () => focused.emit(false))
|
||||
|
||||
|
|
33
websocket.go
33
websocket.go
|
@ -28,10 +28,12 @@ import (
|
|||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/exerrors"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/hicli"
|
||||
"maunium.net/go/mautrix/hicli/database"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/gomuks/pkg/hicli"
|
||||
"go.mau.fi/gomuks/pkg/hicli/database"
|
||||
)
|
||||
|
||||
func writeCmd(ctx context.Context, conn *websocket.Conn, cmd *hicli.JSONCommand) error {
|
||||
|
@ -120,6 +122,17 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||
lastDataReceived := &atomic.Int64{}
|
||||
lastDataReceived.Store(time.Now().UnixMilli())
|
||||
const RecvTimeout = 60 * time.Second
|
||||
lastImageAuthTokenSent := time.Now()
|
||||
sendImageAuthToken := func() {
|
||||
err := writeCmd(ctx, conn, &hicli.JSONCommand{
|
||||
Command: "image_auth_token",
|
||||
Data: exerrors.Must(json.Marshal(gmx.generateImageToken())),
|
||||
})
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to write image auth token message")
|
||||
return
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
defer recoverPanic("event loop")
|
||||
defer closeOnce.Do(forceClose)
|
||||
|
@ -137,6 +150,10 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||
log.Trace().Int64("req_id", cmd.RequestID).Msg("Sent outgoing event")
|
||||
}
|
||||
case <-ticker.C:
|
||||
if time.Since(lastImageAuthTokenSent) > 30*time.Minute {
|
||||
sendImageAuthToken()
|
||||
lastImageAuthTokenSent = time.Now()
|
||||
}
|
||||
if time.Now().UnixMilli()-lastDataReceived.Load() > RecvTimeout.Milliseconds() {
|
||||
log.Warn().Msg("No data received in a minute, closing connection")
|
||||
_ = conn.Close(StatusPingTimeout, "Ping timeout")
|
||||
|
@ -187,6 +204,7 @@ func (gmx *Gomuks) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|||
log.Err(initErr).Msg("Failed to write init message")
|
||||
return
|
||||
}
|
||||
go sendImageAuthToken()
|
||||
go gmx.sendInitialData(ctx, conn)
|
||||
log.Debug().Msg("Connection initialization complete")
|
||||
var closeErr websocket.CloseError
|
||||
|
@ -242,10 +260,11 @@ func (gmx *Gomuks) sendInitialData(ctx context.Context, conn *websocket.Conn) {
|
|||
}
|
||||
maxTS = room.SortingTimestamp.Time
|
||||
syncRoom := &hicli.SyncRoom{
|
||||
Meta: room,
|
||||
Events: make([]*database.Event, 0, 2),
|
||||
Timeline: make([]database.TimelineRowTuple, 0),
|
||||
State: map[event.Type]map[string]database.EventRowID{},
|
||||
Meta: room,
|
||||
Events: make([]*database.Event, 0, 2),
|
||||
Timeline: make([]database.TimelineRowTuple, 0),
|
||||
State: map[event.Type]map[string]database.EventRowID{},
|
||||
Notifications: make([]hicli.SyncNotification, 0),
|
||||
}
|
||||
payload.Rooms[room.ID] = syncRoom
|
||||
if room.PreviewEventRowID != 0 {
|
||||
|
@ -255,6 +274,7 @@ func (gmx *Gomuks) sendInitialData(ctx context.Context, conn *websocket.Conn) {
|
|||
return
|
||||
}
|
||||
if previewEvent != nil {
|
||||
gmx.Client.ReprocessExistingEvent(ctx, previewEvent)
|
||||
previewMember, err := gmx.Client.DB.CurrentState.Get(ctx, room.ID, event.StateMember, previewEvent.Sender.String())
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get preview member event for room")
|
||||
|
@ -269,6 +289,7 @@ func (gmx *Gomuks) sendInitialData(ctx context.Context, conn *websocket.Conn) {
|
|||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get last edit for preview event")
|
||||
} else if lastEdit != nil {
|
||||
gmx.Client.ReprocessExistingEvent(ctx, lastEdit)
|
||||
syncRoom.Events = append(syncRoom.Events, lastEdit)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue