From ee82685b4e405e54efce38620fa71fd8528d8364 Mon Sep 17 00:00:00 2001
From: ErrorNoInternet <errornointernet@envs.net>
Date: Sat, 15 Mar 2025 22:51:55 -0400
Subject: [PATCH] fix(matrix): properly handle sessions

---
 Cargo.lock        |   1 +
 Cargo.toml        |   1 +
 src/events.rs     |  14 ++++--
 src/matrix/bot.rs |   6 +--
 src/matrix/mod.rs | 119 +++++++++++++++++++++++++++++++++++-----------
 5 files changed, 106 insertions(+), 35 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index f05f1ee..204d0ec 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1769,6 +1769,7 @@ dependencies = [
  "mlua",
  "ncr",
  "parking_lot",
+ "serde",
  "serde_json",
  "tokio",
  "zip",
diff --git a/Cargo.toml b/Cargo.toml
index f948448..e8056a8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -40,6 +40,7 @@ mimalloc = { version = "0", optional = true }
 mlua = { version = "0", features = ["async", "luajit", "send"] }
 ncr = { version = "0", features = ["cfb8", "ecb", "gcm"] }
 parking_lot = "0"
+serde = { version = "1", features = ["derive"] }
 serde_json = "1"
 tokio = { version = "1", features = ["macros"] }
 zip = { version = "2", default-features = false, features = ["flate2"] }
diff --git a/src/events.rs b/src/events.rs
index 9f02cbb..e1d56e0 100644
--- a/src/events.rs
+++ b/src/events.rs
@@ -210,12 +210,12 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> Result<
                 exit(0);
             })?;
 
+            #[cfg(feature = "matrix")]
+            matrix_init(&client, state.clone());
+
             let globals = state.lua.globals();
             lua_init(client, &state, &globals).await?;
 
-            #[cfg(feature = "matrix")]
-            matrix_init(state.clone(), &globals);
-
             let Some(address): Option<SocketAddr> = globals
                 .get::<String>("HttpAddress")
                 .ok()
@@ -276,13 +276,17 @@ async fn lua_init(client: Client, state: &State, globals: &Table) -> Result<()>
 }
 
 #[cfg(feature = "matrix")]
-fn matrix_init(state: State, globals: &Table) {
+fn matrix_init(client: &Client, state: State) {
+    let globals = state.lua.globals();
     if let Ok(homeserver_url) = globals.get::<String>("MatrixHomeserverUrl")
         && let Ok(username) = globals.get::<String>("MatrixUsername")
         && let Ok(password) = globals.get::<String>("MatrixPassword")
     {
+        let name = client.username();
         tokio::spawn(async move {
-            if let Err(error) = matrix::login(state, homeserver_url, username, &password).await {
+            if let Err(error) =
+                matrix::login(homeserver_url, username, &password, state, globals, name).await
+            {
                 error!("failed to log into matrix account: {error:?}");
             }
         });
diff --git a/src/matrix/bot.rs b/src/matrix/bot.rs
index 326547b..2758c95 100644
--- a/src/matrix/bot.rs
+++ b/src/matrix/bot.rs
@@ -1,4 +1,4 @@
-use super::{COMMAND_PREFIX, Context};
+use super::{COMMAND_PREFIX, MatrixContext};
 use crate::{
     events::call_listeners,
     lua::{self, matrix::room::Room as LuaRoom},
@@ -19,7 +19,7 @@ use tokio::time::sleep;
 pub async fn on_regular_room_message(
     event: OriginalSyncRoomMessageEvent,
     room: Room,
-    ctx: Ctx<Context>,
+    ctx: Ctx<MatrixContext>,
 ) -> Result<()> {
     if room.state() != RoomState::Joined {
         return Ok(());
@@ -90,7 +90,7 @@ pub async fn on_stripped_state_member(
     member: StrippedRoomMemberEvent,
     client: Client,
     room: Room,
-    ctx: Ctx<Context>,
+    ctx: Ctx<MatrixContext>,
 ) -> Result<()> {
     if let Some(user_id) = client.user_id()
         && member.state_key == user_id
diff --git a/src/matrix/mod.rs b/src/matrix/mod.rs
index 66b789d..f867fe7 100644
--- a/src/matrix/mod.rs
+++ b/src/matrix/mod.rs
@@ -2,56 +2,121 @@ mod bot;
 mod verification;
 
 use crate::{State, lua::matrix::client::Client as LuaClient};
-use anyhow::Result;
+use anyhow::{Context, Result};
 use bot::{on_regular_room_message, on_stripped_state_member};
-use matrix_sdk::{Client, config::SyncSettings};
-use std::{fs, sync::Arc};
+use log::{error, warn};
+use matrix_sdk::{
+    Client, Error, LoopCtrl, authentication::matrix::MatrixSession, config::SyncSettings,
+};
+use mlua::Table;
+use serde::{Deserialize, Serialize};
+use std::{path::Path, sync::Arc};
+use tokio::fs;
 use verification::{on_device_key_verification_request, on_room_message_verification_request};
 
-const COMMAND_PREFIX: &str = "ErrorNoWatcher";
-
 #[derive(Clone)]
-pub struct Context {
+pub struct MatrixContext {
     state: State,
 }
 
+#[derive(Clone, Serialize, Deserialize)]
+struct Session {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    sync_token: Option<String>,
+    user_session: MatrixSession,
+}
+
+async fn persist_sync_token(
+    session_file: &Path,
+    session: &mut Session,
+    sync_token: String,
+) -> Result<()> {
+    session.sync_token = Some(sync_token);
+    fs::write(session_file, serde_json::to_string(&session)?).await?;
+    Ok(())
+}
+
 pub async fn login(
-    state: State,
     homeserver_url: String,
     username: String,
     password: &str,
+    state: State,
+    globals: Table,
+    name: String,
 ) -> Result<()> {
-    let mut client = Client::builder().homeserver_url(homeserver_url);
-    if let Some(db_path) = dirs::data_dir().map(|path| path.join("errornowatcher").join("matrix"))
-        && fs::create_dir_all(&db_path).is_ok()
+    let root_dir = dirs::data_dir()
+        .context("no data directory")?
+        .join("errornowatcher")
+        .join(&name)
+        .join("matrix");
+
+    let mut builder = Client::builder().homeserver_url(homeserver_url);
+    if !fs::try_exists(&root_dir).await.unwrap_or_default()
+        && let Err(error) = fs::create_dir_all(&root_dir).await
     {
-        client = client.sqlite_store(db_path, None);
+        warn!("failed to create directory for matrix sqlite3 store: {error:?}");
+    } else {
+        builder = builder.sqlite_store(&root_dir, None);
+    }
+    let client = builder.build().await?;
+
+    let mut new_session;
+    let session_file = root_dir.join("session.json");
+    let mut sync_settings = SyncSettings::default();
+    if let Some(session) = fs::read_to_string(&session_file)
+        .await
+        .ok()
+        .and_then(|data| serde_json::from_str::<Session>(&data).ok())
+    {
+        new_session = session.clone();
+        if let Some(sync_token) = session.sync_token {
+            sync_settings = sync_settings.token(sync_token);
+        }
+        client.restore_session(session.user_session).await?;
+    } else {
+        let matrix_auth = client.matrix_auth();
+        matrix_auth
+            .login_username(username, password)
+            .initial_device_display_name(&name)
+            .await?;
+
+        new_session = Session {
+            user_session: matrix_auth.session().context("should have session")?,
+            sync_token: None,
+        };
+        fs::write(&session_file, serde_json::to_string(&new_session)?).await?;
     }
 
-    let client = Arc::new(client.build().await?);
-    client
-        .matrix_auth()
-        .login_username(username, password)
-        .device_id("ERRORNOWATCHER")
-        .initial_device_display_name("ErrorNoWatcher")
-        .await?;
-
+    client.add_event_handler_context(MatrixContext { state });
     client.add_event_handler(on_stripped_state_member);
-    let response = client.sync_once(SyncSettings::default()).await?;
+    loop {
+        match client.sync_once(sync_settings.clone()).await {
+            Ok(response) => {
+                sync_settings = sync_settings.token(response.next_batch.clone());
+                persist_sync_token(&session_file, &mut new_session, response.next_batch).await?;
+                break;
+            }
+            Err(error) => {
+                error!("failed to do initial sync: {error:?}");
+            }
+        }
+    }
 
     client.add_event_handler(on_device_key_verification_request);
     client.add_event_handler(on_room_message_verification_request);
     client.add_event_handler(on_regular_room_message);
 
-    state
-        .lua
-        .globals()
-        .set("matrix", LuaClient(client.clone()))?;
+    let client = Arc::new(client);
+    globals.set("matrix", LuaClient(client.clone()))?;
 
-    client.add_event_handler_context(Context { state });
     client
-        .sync(SyncSettings::default().token(response.next_batch))
+        .sync_with_result_callback(sync_settings, |sync_result| async {
+            let mut new_session = new_session.clone();
+            persist_sync_token(&session_file, &mut new_session, sync_result?.next_batch)
+                .await
+                .map_err(|err| Error::UnknownError(err.into()))?;
+            Ok(LoopCtrl::Continue)
+        })
         .await?;
-
     Ok(())
 }