From bd6698c4b4e011bb4d356b1a54482f3333961d71 Mon Sep 17 00:00:00 2001
From: ErrorNoInternet <errornointernet@envs.net>
Date: Thu, 20 Feb 2025 21:07:41 -0500
Subject: [PATCH] feat: add event listeners

---
 lib/events.lua    | 14 ++++++---
 src/events.rs     | 41 +++++++++++++++---------
 src/lua/events.rs | 79 +++++++++++++++++++++++++++++++++++++++++++++++
 src/lua/mod.rs    |  5 +--
 src/main.rs       | 12 ++++---
 5 files changed, 126 insertions(+), 25 deletions(-)
 create mode 100644 src/lua/events.rs

diff --git a/lib/events.lua b/lib/events.lua
index 8a960c4..7bd4c55 100644
--- a/lib/events.lua
+++ b/lib/events.lua
@@ -1,7 +1,7 @@
 local center = { x = 0, z = 0 }
 local radius = 100
 
-function on_tick()
+function log_player_positions()
 	local entities = client:find_entities(function(e)
 		return e.kind == "minecraft:player"
 			and e.position.x > center.x - radius + 1
@@ -17,8 +17,14 @@ end
 function on_init()
 	info("client initialized, setting information...")
 	client:set_client_information({ view_distance = 16 })
-end
 
-function on_login()
-	info("player successfully logged in!")
+	add_listener("login", function()
+		info("player successfully logged in!")
+	end)
+
+	add_listener("death", function()
+		warn("player died!")
+	end, "warn_player_died")
+
+	add_listener("tick", log_player_positions)
 end
diff --git a/src/events.rs b/src/events.rs
index fec8620..6dbd624 100644
--- a/src/events.rs
+++ b/src/events.rs
@@ -1,9 +1,14 @@
-use crate::{State, commands::CommandSource, http::serve, lua};
+use crate::{
+    State,
+    commands::CommandSource,
+    http::serve,
+    lua::{self, events::register_functions},
+};
 use azalea::prelude::*;
 use hyper::{server::conn::http1, service::service_fn};
 use hyper_util::rt::TokioIo;
 use log::{debug, error, info, trace};
-use mlua::{Function, IntoLuaMulti, Table};
+use mlua::{Function, IntoLuaMulti};
 use tokio::net::TcpListener;
 
 pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow::Result<()> {
@@ -32,23 +37,22 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow:
                     CommandSource {
                         client,
                         message,
-                        state,
+                        state: state.clone(),
                     }
                     .reply(&format!("{error:?}"));
                 }
             }
 
-            call_lua_handler(&globals, "on_chat", ());
+            call_listeners(&state, "chat", formatted_message.to_string()).await;
         }
         Event::Death(Some(packet)) => {
             let death_data = state.lua.create_table()?;
             death_data.set("message", packet.message.to_string())?;
             death_data.set("player_id", packet.player_id)?;
-
-            call_lua_handler(&globals, "on_death", death_data);
+            call_listeners(&state, "death", death_data).await;
         }
-        Event::Tick => call_lua_handler(&globals, "on_tick", ()),
-        Event::Login => call_lua_handler(&globals, "on_login", ()),
+        Event::Login => call_listeners(&state, "login", ()).await,
+        Event::Tick => call_listeners(&state, "tick", ()).await,
         Event::Init => {
             debug!("client initialized");
 
@@ -58,7 +62,12 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow:
                     inner: Some(client),
                 },
             )?;
-            call_lua_handler(&globals, "on_init", ());
+            register_functions(&state.lua, &globals, state.clone()).await?;
+            if let Ok(on_init) = globals.get::<Function>("on_init")
+                && let Err(error) = on_init.call::<()>(())
+            {
+                error!("failed to call lua on_init function: {error:?}");
+            }
 
             if let Some(address) = state.http_address {
                 let listener = TcpListener::bind(address).await.map_err(|error| {
@@ -77,7 +86,7 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow:
                         async move { serve(request, request_state).await }
                     });
 
-                    tokio::task::spawn(async move {
+                    tokio::spawn(async move {
                         if let Err(error) = http1::Builder::new()
                             .serve_connection(TokioIo::new(stream), service)
                             .await
@@ -94,10 +103,12 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow:
     Ok(())
 }
 
-fn call_lua_handler<T: IntoLuaMulti>(globals: &Table, name: &str, data: T) {
-    if let Ok(handler) = globals.get::<Function>(name)
-        && let Err(error) = handler.call::<()>(data)
-    {
-        error!("failed to call lua {name} function: {error:?}");
+async fn call_listeners<T: Clone + IntoLuaMulti>(state: &State, event_type: &str, data: T) {
+    if let Some(listeners) = state.event_listeners.lock().await.get(event_type) {
+        for (_, listener) in listeners {
+            if let Err(error) = listener.call_async::<()>(data.clone()).await {
+                error!("failed to call lua event listener for {event_type}: {error:?}");
+            }
+        }
     }
 }
diff --git a/src/lua/events.rs b/src/lua/events.rs
new file mode 100644
index 0000000..5b0e1c3
--- /dev/null
+++ b/src/lua/events.rs
@@ -0,0 +1,79 @@
+use crate::State;
+use futures::executor::block_on;
+use mlua::{Function, Lua, Result, Table};
+use std::time::{SystemTime, UNIX_EPOCH};
+
+pub async fn register_functions(lua: &Lua, globals: &Table, state: State) -> Result<()> {
+    let l = state.event_listeners.clone();
+    globals.set(
+        "add_listener",
+        lua.create_function(
+            move |_, (event_type, callback, id): (String, Function, Option<String>)| {
+                let mut l = block_on(l.lock());
+
+                l.entry(event_type).or_default().push((
+                    id.unwrap_or(callback.info().name.unwrap_or(format!(
+                            "anonymous @ {}",
+                            SystemTime::now()
+                                .duration_since(UNIX_EPOCH)
+                                .unwrap_or_default()
+                                .as_millis()
+                        ))),
+                    callback,
+                ));
+                Ok(())
+            },
+        )?,
+    )?;
+
+    let l = state.event_listeners.clone();
+    globals.set(
+        "remove_listener",
+        lua.create_function(move |_, (event_type, target_id): (String, String)| {
+            let mut l = block_on(l.lock());
+
+            let empty = if let Some(listeners) = l.get_mut(&event_type) {
+                listeners.retain(|(id, _)| target_id != *id);
+                listeners.is_empty()
+            } else {
+                false
+            };
+            if empty {
+                l.remove(&event_type);
+            }
+
+            Ok(())
+        })?,
+    )?;
+
+    globals.set(
+        "get_listeners",
+        lua.create_function(move |lua, (): ()| {
+            let l = block_on(state.event_listeners.lock());
+
+            let listeners = lua.create_table()?;
+            for (event_type, callbacks) in l.iter() {
+                let type_listeners = lua.create_table()?;
+                for (id, callback) in callbacks {
+                    let listener = lua.create_table()?;
+                    let i = callback.info();
+                    if let Some(n) = i.name {
+                        listener.set("name", n)?;
+                    }
+                    if let Some(l) = i.line_defined {
+                        listener.set("line_defined", l)?;
+                    }
+                    if let Some(s) = i.source {
+                        listener.set("source", s)?;
+                    }
+                    type_listeners.set(id.to_owned(), listener)?;
+                }
+                listeners.set(event_type.to_owned(), type_listeners)?;
+            }
+
+            Ok(listeners)
+        })?,
+    )?;
+
+    Ok(())
+}
diff --git a/src/lua/mod.rs b/src/lua/mod.rs
index 59bfbfa..c83b62d 100644
--- a/src/lua/mod.rs
+++ b/src/lua/mod.rs
@@ -2,6 +2,7 @@ pub mod block;
 pub mod client;
 pub mod container;
 pub mod direction;
+pub mod events;
 pub mod logging;
 pub mod vec3;
 
@@ -26,8 +27,8 @@ pub fn register_functions(lua: &Lua, globals: &Table) -> mlua::Result<()> {
         })?,
     )?;
 
-    logging::register_functions(lua, globals)?;
-    block::register_functions(lua, globals)
+    block::register_functions(lua, globals)?;
+    logging::register_functions(lua, globals)
 }
 
 pub fn reload(lua: &Lua) -> Result<(), Error> {
diff --git a/src/main.rs b/src/main.rs
index 0bb76c9..c2f336e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -11,16 +11,19 @@ use clap::Parser;
 use commands::{CommandSource, register};
 use events::handle_event;
 use futures::lock::Mutex;
-use mlua::Lua;
-use std::{net::SocketAddr, path::PathBuf, sync::Arc};
+use mlua::{Function, Lua};
+use std::{collections::HashMap, net::SocketAddr, path::PathBuf, sync::Arc};
 
 const DEFAULT_SCRIPT_PATH: &str = "errornowatcher.lua";
 
+type ListenerMap = HashMap<String, Vec<(String, Function)>>;
+
 #[derive(Default, Clone, Component)]
 pub struct State {
     lua: Lua,
-    http_address: Option<SocketAddr>,
+    event_listeners: Arc<Mutex<ListenerMap>>,
     commands: Arc<CommandDispatcher<Mutex<CommandSource>>>,
+    http_address: Option<SocketAddr>,
 }
 
 #[tokio::main]
@@ -53,8 +56,9 @@ async fn main() -> anyhow::Result<()> {
         .set_handler(handle_event)
         .set_state(State {
             lua,
-            http_address: args.http_address,
+            event_listeners: Arc::new(Mutex::new(HashMap::new())),
             commands: Arc::new(commands),
+            http_address: args.http_address,
         })
         .start(
             if username.contains('@') {