diff --git a/lib/events.lua b/lib/events.lua index 25cc95b..0bd27da 100644 --- a/lib/events.lua +++ b/lib/events.lua @@ -14,17 +14,17 @@ function log_player_positions() end end -function on_init() +add_listener("init", function() info("client initialized, setting information...") client:set_client_information({ view_distance = 16 }) +end) - add_listener("login", function() - info("player successfully logged in!") - end) +add_listener("login", function() + info("player successfully logged in!") +end) - add_listener("death", function() - warn(string.format("player died at %.1f %.1f %.1f!", client.position.x, client.position.y, client.position.z)) - end, "warn_player_died") +add_listener("death", function() + warn(string.format("player died at %.1f %.1f %.1f!", client.position.x, client.position.y, client.position.z)) +end, "warn_player_died") - add_listener("tick", log_player_positions) -end +add_listener("tick", log_player_positions) diff --git a/src/events.rs b/src/events.rs index ac01c6c..ec2bea8 100644 --- a/src/events.rs +++ b/src/events.rs @@ -4,19 +4,18 @@ use crate::{ State, commands::CommandSource, http::serve, - lua::{self, events::register_functions, player::Player}, + lua::{self, player::Player}, }; use azalea::{prelude::*, protocol::packets::game::ClientboundGamePacket}; use hyper::{server::conn::http1, service::service_fn}; use hyper_util::rt::TokioIo; use log::{debug, error, info, trace}; -use mlua::{Function, IntoLuaMulti}; +use mlua::IntoLuaMulti; use tokio::net::TcpListener; #[allow(clippy::too_many_lines)] pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow::Result<()> { state.lua.gc_stop(); - let globals = state.lua.globals(); match event { Event::AddPlayer(player_info) => { @@ -26,7 +25,7 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow: let formatted_message = message.message(); info!("{}", formatted_message.to_ansi()); - let owners = globals.get::>("Owners")?; + let owners = state.lua.globals().get::>("Owners")?; if message.is_whisper() && let (Some(sender), content) = message.split_sender_and_content() && owners.contains(&sender) @@ -78,20 +77,16 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow: } } Event::Init => { - debug!("client initialized"); + debug!("received initialize event"); + let globals = state.lua.globals(); globals.set( "client", lua::client::Client { inner: Some(client), }, )?; - register_functions(&state.lua, &globals, state.clone()).await?; - if let Ok(on_init) = globals.get::("on_init") - && let Err(error) = on_init.call::<()>(()) - { - error!("failed to call lua on_init function: {error:?}"); - } + call_listeners(&state, "init", ()).await; if let Some(address) = state.http_address { let listener = TcpListener::bind(address).await.map_err(|error| { diff --git a/src/lua/events.rs b/src/lua/events.rs index 4dfdfd3..15ab622 100644 --- a/src/lua/events.rs +++ b/src/lua/events.rs @@ -1,17 +1,24 @@ -use crate::State; -use futures::executor::block_on; +use crate::ListenerMap; +use futures::{executor::block_on, lock::Mutex}; use mlua::{Function, Lua, Result, Table}; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::{ + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; -pub async fn register_functions(lua: &Lua, globals: &Table, state: State) -> Result<()> { - let l = state.event_listeners.clone(); +pub fn register_functions( + lua: &Lua, + globals: &Table, + event_listeners: Arc>, +) -> Result<()> { + let m = event_listeners.clone(); globals.set( "add_listener", lua.create_function( move |_, (event_type, callback, id): (String, Function, Option)| { - let l = l.clone(); + let m = m.clone(); tokio::spawn(async move { - l.lock().await.entry(event_type).or_default().push(( + m.lock().await.entry(event_type).or_default().push(( id.unwrap_or(callback.info().name.unwrap_or(format!( "anonymous @ {}", SystemTime::now() @@ -27,21 +34,21 @@ pub async fn register_functions(lua: &Lua, globals: &Table, state: State) -> Res )?, )?; - let l = state.event_listeners.clone(); + let m = event_listeners.clone(); globals.set( "remove_listener", lua.create_function(move |_, (event_type, target_id): (String, String)| { - let l = l.clone(); + let m = m.clone(); tokio::spawn(async move { - let mut l = l.lock().await; - let empty = if let Some(listeners) = l.get_mut(&event_type) { + let mut m = m.lock().await; + let empty = if let Some(listeners) = m.get_mut(&event_type) { listeners.retain(|(id, _)| target_id != *id); listeners.is_empty() } else { false }; if empty { - l.remove(&event_type); + m.remove(&event_type); } }); Ok(()) @@ -51,10 +58,10 @@ pub async fn register_functions(lua: &Lua, globals: &Table, state: State) -> Res globals.set( "get_listeners", lua.create_function(move |lua, (): ()| { - let l = block_on(state.event_listeners.lock()); + let m = block_on(event_listeners.lock()); let listeners = lua.create_table()?; - for (event_type, callbacks) in l.iter() { + for (event_type, callbacks) in m.iter() { let type_listeners = lua.create_table()?; for (id, callback) in callbacks { let listener = lua.create_table()?; diff --git a/src/lua/mod.rs b/src/lua/mod.rs index 9b7ffc4..b4587c0 100644 --- a/src/lua/mod.rs +++ b/src/lua/mod.rs @@ -8,7 +8,10 @@ pub mod player; pub mod system; pub mod vec3; +use crate::ListenerMap; +use futures::lock::Mutex; use mlua::{Lua, Table}; +use std::{io, sync::Arc, time::Duration}; #[derive(Debug)] #[allow(dead_code)] @@ -18,19 +21,24 @@ pub enum Error { ExecChunk(mlua::Error), LoadChunk(mlua::Error), MissingPath(mlua::Error), - ReadFile(std::io::Error), + ReadFile(io::Error), } -pub fn register_functions(lua: &Lua, globals: &Table) -> mlua::Result<()> { +pub fn register_functions( + lua: &Lua, + globals: &Table, + event_listeners: Arc>, +) -> mlua::Result<()> { globals.set( "sleep", lua.create_async_function(async |_, duration: u64| { - tokio::time::sleep(std::time::Duration::from_millis(duration)).await; + tokio::time::sleep(Duration::from_millis(duration)).await; Ok(()) })?, )?; block::register_functions(lua, globals)?; + events::register_functions(lua, globals, event_listeners)?; logging::register_functions(lua, globals)?; system::register_functions(lua, globals) } diff --git a/src/main.rs b/src/main.rs index 5c008df..42ad5b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,6 @@ use bevy_log::{ }; use clap::Parser; use commands::{CommandSource, register}; -use events::handle_event; use futures::lock::Mutex; use mlua::{Function, Lua}; use std::{ @@ -43,13 +42,14 @@ pub struct State { #[tokio::main] async fn main() -> anyhow::Result<()> { let args = arguments::Arguments::parse(); + let script_path = args.script.unwrap_or(PathBuf::from(DEFAULT_SCRIPT_PATH)); + let event_listeners = Arc::new(Mutex::new(HashMap::new())); let lua = Lua::new(); let globals = lua.globals(); globals.set("script_path", &*script_path)?; - lua::register_functions(&lua, &globals)?; - + lua::register_functions(&lua, &globals, event_listeners.clone())?; lua.load( read_to_string(script_path) .expect(&(DEFAULT_SCRIPT_PATH.to_owned() + " should be in current directory")), @@ -65,28 +65,29 @@ async fn main() -> anyhow::Result<()> { let mut commands = CommandDispatcher::new(); register(&mut commands); + let log_plugin = LogPlugin { + custom_layer: |_| { + env::var("LOG_FILE").ok().map(|log_file| { + layer() + .with_writer( + OpenOptions::new() + .append(true) + .create(true) + .open(log_file) + .expect("should have been able to open log file"), + ) + .boxed() + }) + }, + ..Default::default() + }; let Err(error) = ClientBuilder::new_without_plugins() - .add_plugins(DefaultPlugins.set(LogPlugin { - custom_layer: |_| { - env::var("LOG_FILE").ok().map(|log_file| { - layer() - .with_writer( - OpenOptions::new() - .append(true) - .create(true) - .open(log_file) - .expect("should have been able to open log file"), - ) - .boxed() - }) - }, - ..Default::default() - })) + .add_plugins(DefaultPlugins.set(log_plugin)) .add_plugins(DefaultBotPlugins) - .set_handler(handle_event) + .set_handler(events::handle_event) .set_state(State { lua, - event_listeners: Arc::new(Mutex::new(HashMap::new())), + event_listeners, commands: Arc::new(commands), http_address: args.http_address, })