diff --git a/lib/events.lua b/lib/events.lua index 8a960c4..7bd4c55 100644 --- a/lib/events.lua +++ b/lib/events.lua @@ -1,7 +1,7 @@ local center = { x = 0, z = 0 } local radius = 100 -function on_tick() +function log_player_positions() local entities = client:find_entities(function(e) return e.kind == "minecraft:player" and e.position.x > center.x - radius + 1 @@ -17,8 +17,14 @@ end function on_init() info("client initialized, setting information...") client:set_client_information({ view_distance = 16 }) -end -function on_login() - info("player successfully logged in!") + add_listener("login", function() + info("player successfully logged in!") + end) + + add_listener("death", function() + warn("player died!") + end, "warn_player_died") + + add_listener("tick", log_player_positions) end diff --git a/src/events.rs b/src/events.rs index fec8620..6dbd624 100644 --- a/src/events.rs +++ b/src/events.rs @@ -1,9 +1,14 @@ -use crate::{State, commands::CommandSource, http::serve, lua}; +use crate::{ + State, + commands::CommandSource, + http::serve, + lua::{self, events::register_functions}, +}; use azalea::prelude::*; use hyper::{server::conn::http1, service::service_fn}; use hyper_util::rt::TokioIo; use log::{debug, error, info, trace}; -use mlua::{Function, IntoLuaMulti, Table}; +use mlua::{Function, IntoLuaMulti}; use tokio::net::TcpListener; pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow::Result<()> { @@ -32,23 +37,22 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow: CommandSource { client, message, - state, + state: state.clone(), } .reply(&format!("{error:?}")); } } - call_lua_handler(&globals, "on_chat", ()); + call_listeners(&state, "chat", formatted_message.to_string()).await; } Event::Death(Some(packet)) => { let death_data = state.lua.create_table()?; death_data.set("message", packet.message.to_string())?; death_data.set("player_id", packet.player_id)?; - - call_lua_handler(&globals, "on_death", death_data); + call_listeners(&state, "death", death_data).await; } - Event::Tick => call_lua_handler(&globals, "on_tick", ()), - Event::Login => call_lua_handler(&globals, "on_login", ()), + Event::Login => call_listeners(&state, "login", ()).await, + Event::Tick => call_listeners(&state, "tick", ()).await, Event::Init => { debug!("client initialized"); @@ -58,7 +62,12 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow: inner: Some(client), }, )?; - call_lua_handler(&globals, "on_init", ()); + register_functions(&state.lua, &globals, state.clone()).await?; + if let Ok(on_init) = globals.get::("on_init") + && let Err(error) = on_init.call::<()>(()) + { + error!("failed to call lua on_init function: {error:?}"); + } if let Some(address) = state.http_address { let listener = TcpListener::bind(address).await.map_err(|error| { @@ -77,7 +86,7 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow: async move { serve(request, request_state).await } }); - tokio::task::spawn(async move { + tokio::spawn(async move { if let Err(error) = http1::Builder::new() .serve_connection(TokioIo::new(stream), service) .await @@ -94,10 +103,12 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow: Ok(()) } -fn call_lua_handler(globals: &Table, name: &str, data: T) { - if let Ok(handler) = globals.get::(name) - && let Err(error) = handler.call::<()>(data) - { - error!("failed to call lua {name} function: {error:?}"); +async fn call_listeners(state: &State, event_type: &str, data: T) { + if let Some(listeners) = state.event_listeners.lock().await.get(event_type) { + for (_, listener) in listeners { + if let Err(error) = listener.call_async::<()>(data.clone()).await { + error!("failed to call lua event listener for {event_type}: {error:?}"); + } + } } } diff --git a/src/lua/events.rs b/src/lua/events.rs new file mode 100644 index 0000000..5b0e1c3 --- /dev/null +++ b/src/lua/events.rs @@ -0,0 +1,79 @@ +use crate::State; +use futures::executor::block_on; +use mlua::{Function, Lua, Result, Table}; +use std::time::{SystemTime, UNIX_EPOCH}; + +pub async fn register_functions(lua: &Lua, globals: &Table, state: State) -> Result<()> { + let l = state.event_listeners.clone(); + globals.set( + "add_listener", + lua.create_function( + move |_, (event_type, callback, id): (String, Function, Option)| { + let mut l = block_on(l.lock()); + + l.entry(event_type).or_default().push(( + id.unwrap_or(callback.info().name.unwrap_or(format!( + "anonymous @ {}", + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() + ))), + callback, + )); + Ok(()) + }, + )?, + )?; + + let l = state.event_listeners.clone(); + globals.set( + "remove_listener", + lua.create_function(move |_, (event_type, target_id): (String, String)| { + let mut l = block_on(l.lock()); + + let empty = if let Some(listeners) = l.get_mut(&event_type) { + listeners.retain(|(id, _)| target_id != *id); + listeners.is_empty() + } else { + false + }; + if empty { + l.remove(&event_type); + } + + Ok(()) + })?, + )?; + + globals.set( + "get_listeners", + lua.create_function(move |lua, (): ()| { + let l = block_on(state.event_listeners.lock()); + + let listeners = lua.create_table()?; + for (event_type, callbacks) in l.iter() { + let type_listeners = lua.create_table()?; + for (id, callback) in callbacks { + let listener = lua.create_table()?; + let i = callback.info(); + if let Some(n) = i.name { + listener.set("name", n)?; + } + if let Some(l) = i.line_defined { + listener.set("line_defined", l)?; + } + if let Some(s) = i.source { + listener.set("source", s)?; + } + type_listeners.set(id.to_owned(), listener)?; + } + listeners.set(event_type.to_owned(), type_listeners)?; + } + + Ok(listeners) + })?, + )?; + + Ok(()) +} diff --git a/src/lua/mod.rs b/src/lua/mod.rs index 59bfbfa..c83b62d 100644 --- a/src/lua/mod.rs +++ b/src/lua/mod.rs @@ -2,6 +2,7 @@ pub mod block; pub mod client; pub mod container; pub mod direction; +pub mod events; pub mod logging; pub mod vec3; @@ -26,8 +27,8 @@ pub fn register_functions(lua: &Lua, globals: &Table) -> mlua::Result<()> { })?, )?; - logging::register_functions(lua, globals)?; - block::register_functions(lua, globals) + block::register_functions(lua, globals)?; + logging::register_functions(lua, globals) } pub fn reload(lua: &Lua) -> Result<(), Error> { diff --git a/src/main.rs b/src/main.rs index 0bb76c9..c2f336e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,16 +11,19 @@ use clap::Parser; use commands::{CommandSource, register}; use events::handle_event; use futures::lock::Mutex; -use mlua::Lua; -use std::{net::SocketAddr, path::PathBuf, sync::Arc}; +use mlua::{Function, Lua}; +use std::{collections::HashMap, net::SocketAddr, path::PathBuf, sync::Arc}; const DEFAULT_SCRIPT_PATH: &str = "errornowatcher.lua"; +type ListenerMap = HashMap>; + #[derive(Default, Clone, Component)] pub struct State { lua: Lua, - http_address: Option, + event_listeners: Arc>, commands: Arc>>, + http_address: Option, } #[tokio::main] @@ -53,8 +56,9 @@ async fn main() -> anyhow::Result<()> { .set_handler(handle_event) .set_state(State { lua, - http_address: args.http_address, + event_listeners: Arc::new(Mutex::new(HashMap::new())), commands: Arc::new(commands), + http_address: args.http_address, }) .start( if username.contains('@') {