From ee82685b4e405e54efce38620fa71fd8528d8364 Mon Sep 17 00:00:00 2001 From: ErrorNoInternet <errornointernet@envs.net> Date: Sat, 15 Mar 2025 22:51:55 -0400 Subject: [PATCH] fix(matrix): properly handle sessions --- Cargo.lock | 1 + Cargo.toml | 1 + src/events.rs | 14 ++++-- src/matrix/bot.rs | 6 +-- src/matrix/mod.rs | 119 +++++++++++++++++++++++++++++++++++----------- 5 files changed, 106 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f05f1ee..204d0ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1769,6 +1769,7 @@ dependencies = [ "mlua", "ncr", "parking_lot", + "serde", "serde_json", "tokio", "zip", diff --git a/Cargo.toml b/Cargo.toml index f948448..e8056a8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ mimalloc = { version = "0", optional = true } mlua = { version = "0", features = ["async", "luajit", "send"] } ncr = { version = "0", features = ["cfb8", "ecb", "gcm"] } parking_lot = "0" +serde = { version = "1", features = ["derive"] } serde_json = "1" tokio = { version = "1", features = ["macros"] } zip = { version = "2", default-features = false, features = ["flate2"] } diff --git a/src/events.rs b/src/events.rs index 9f02cbb..e1d56e0 100644 --- a/src/events.rs +++ b/src/events.rs @@ -210,12 +210,12 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> Result< exit(0); })?; + #[cfg(feature = "matrix")] + matrix_init(&client, state.clone()); + let globals = state.lua.globals(); lua_init(client, &state, &globals).await?; - #[cfg(feature = "matrix")] - matrix_init(state.clone(), &globals); - let Some(address): Option<SocketAddr> = globals .get::<String>("HttpAddress") .ok() @@ -276,13 +276,17 @@ async fn lua_init(client: Client, state: &State, globals: &Table) -> Result<()> } #[cfg(feature = "matrix")] -fn matrix_init(state: State, globals: &Table) { +fn matrix_init(client: &Client, state: State) { + let globals = state.lua.globals(); if let Ok(homeserver_url) = globals.get::<String>("MatrixHomeserverUrl") && let Ok(username) = globals.get::<String>("MatrixUsername") && let Ok(password) = globals.get::<String>("MatrixPassword") { + let name = client.username(); tokio::spawn(async move { - if let Err(error) = matrix::login(state, homeserver_url, username, &password).await { + if let Err(error) = + matrix::login(homeserver_url, username, &password, state, globals, name).await + { error!("failed to log into matrix account: {error:?}"); } }); diff --git a/src/matrix/bot.rs b/src/matrix/bot.rs index 326547b..2758c95 100644 --- a/src/matrix/bot.rs +++ b/src/matrix/bot.rs @@ -1,4 +1,4 @@ -use super::{COMMAND_PREFIX, Context}; +use super::{COMMAND_PREFIX, MatrixContext}; use crate::{ events::call_listeners, lua::{self, matrix::room::Room as LuaRoom}, @@ -19,7 +19,7 @@ use tokio::time::sleep; pub async fn on_regular_room_message( event: OriginalSyncRoomMessageEvent, room: Room, - ctx: Ctx<Context>, + ctx: Ctx<MatrixContext>, ) -> Result<()> { if room.state() != RoomState::Joined { return Ok(()); @@ -90,7 +90,7 @@ pub async fn on_stripped_state_member( member: StrippedRoomMemberEvent, client: Client, room: Room, - ctx: Ctx<Context>, + ctx: Ctx<MatrixContext>, ) -> Result<()> { if let Some(user_id) = client.user_id() && member.state_key == user_id diff --git a/src/matrix/mod.rs b/src/matrix/mod.rs index 66b789d..f867fe7 100644 --- a/src/matrix/mod.rs +++ b/src/matrix/mod.rs @@ -2,56 +2,121 @@ mod bot; mod verification; use crate::{State, lua::matrix::client::Client as LuaClient}; -use anyhow::Result; +use anyhow::{Context, Result}; use bot::{on_regular_room_message, on_stripped_state_member}; -use matrix_sdk::{Client, config::SyncSettings}; -use std::{fs, sync::Arc}; +use log::{error, warn}; +use matrix_sdk::{ + Client, Error, LoopCtrl, authentication::matrix::MatrixSession, config::SyncSettings, +}; +use mlua::Table; +use serde::{Deserialize, Serialize}; +use std::{path::Path, sync::Arc}; +use tokio::fs; use verification::{on_device_key_verification_request, on_room_message_verification_request}; -const COMMAND_PREFIX: &str = "ErrorNoWatcher"; - #[derive(Clone)] -pub struct Context { +pub struct MatrixContext { state: State, } +#[derive(Clone, Serialize, Deserialize)] +struct Session { + #[serde(skip_serializing_if = "Option::is_none")] + sync_token: Option<String>, + user_session: MatrixSession, +} + +async fn persist_sync_token( + session_file: &Path, + session: &mut Session, + sync_token: String, +) -> Result<()> { + session.sync_token = Some(sync_token); + fs::write(session_file, serde_json::to_string(&session)?).await?; + Ok(()) +} + pub async fn login( - state: State, homeserver_url: String, username: String, password: &str, + state: State, + globals: Table, + name: String, ) -> Result<()> { - let mut client = Client::builder().homeserver_url(homeserver_url); - if let Some(db_path) = dirs::data_dir().map(|path| path.join("errornowatcher").join("matrix")) - && fs::create_dir_all(&db_path).is_ok() + let root_dir = dirs::data_dir() + .context("no data directory")? + .join("errornowatcher") + .join(&name) + .join("matrix"); + + let mut builder = Client::builder().homeserver_url(homeserver_url); + if !fs::try_exists(&root_dir).await.unwrap_or_default() + && let Err(error) = fs::create_dir_all(&root_dir).await { - client = client.sqlite_store(db_path, None); + warn!("failed to create directory for matrix sqlite3 store: {error:?}"); + } else { + builder = builder.sqlite_store(&root_dir, None); + } + let client = builder.build().await?; + + let mut new_session; + let session_file = root_dir.join("session.json"); + let mut sync_settings = SyncSettings::default(); + if let Some(session) = fs::read_to_string(&session_file) + .await + .ok() + .and_then(|data| serde_json::from_str::<Session>(&data).ok()) + { + new_session = session.clone(); + if let Some(sync_token) = session.sync_token { + sync_settings = sync_settings.token(sync_token); + } + client.restore_session(session.user_session).await?; + } else { + let matrix_auth = client.matrix_auth(); + matrix_auth + .login_username(username, password) + .initial_device_display_name(&name) + .await?; + + new_session = Session { + user_session: matrix_auth.session().context("should have session")?, + sync_token: None, + }; + fs::write(&session_file, serde_json::to_string(&new_session)?).await?; } - let client = Arc::new(client.build().await?); - client - .matrix_auth() - .login_username(username, password) - .device_id("ERRORNOWATCHER") - .initial_device_display_name("ErrorNoWatcher") - .await?; - + client.add_event_handler_context(MatrixContext { state }); client.add_event_handler(on_stripped_state_member); - let response = client.sync_once(SyncSettings::default()).await?; + loop { + match client.sync_once(sync_settings.clone()).await { + Ok(response) => { + sync_settings = sync_settings.token(response.next_batch.clone()); + persist_sync_token(&session_file, &mut new_session, response.next_batch).await?; + break; + } + Err(error) => { + error!("failed to do initial sync: {error:?}"); + } + } + } client.add_event_handler(on_device_key_verification_request); client.add_event_handler(on_room_message_verification_request); client.add_event_handler(on_regular_room_message); - state - .lua - .globals() - .set("matrix", LuaClient(client.clone()))?; + let client = Arc::new(client); + globals.set("matrix", LuaClient(client.clone()))?; - client.add_event_handler_context(Context { state }); client - .sync(SyncSettings::default().token(response.next_batch)) + .sync_with_result_callback(sync_settings, |sync_result| async { + let mut new_session = new_session.clone(); + persist_sync_token(&session_file, &mut new_session, sync_result?.next_batch) + .await + .map_err(|err| Error::UnknownError(err.into()))?; + Ok(LoopCtrl::Continue) + }) .await?; - Ok(()) }