diff --git a/Cargo.lock b/Cargo.lock index c52085d..c0baeb3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1184,6 +1184,7 @@ dependencies = [ "bevy_log", "clap", "futures", + "futures-locks", "http-body-util", "hyper", "hyper-util", @@ -1317,6 +1318,17 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "futures-locks" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45ec6fe3675af967e67c5536c0b9d44e34e6c52f86bedc4ea49c5317b8e94d06" +dependencies = [ + "futures-channel", + "futures-task", + "tokio", +] + [[package]] name = "futures-macro" version = "0.3.31" diff --git a/Cargo.toml b/Cargo.toml index 3d9f48a..0e84dee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ bevy_app = "0" bevy_log = "0" clap = { version = "4", features = ["derive"] } futures = "0" +futures-locks = "0" http-body-util = "0" hyper = { version = "1", features = ["server"] } hyper-util = "0" diff --git a/src/events.rs b/src/events.rs index ec2bea8..5db01e3 100644 --- a/src/events.rs +++ b/src/events.rs @@ -123,7 +123,7 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow: } async fn call_listeners(state: &State, event_type: &str, data: T) { - if let Some(listeners) = state.event_listeners.lock().await.get(event_type) { + if let Some(listeners) = state.event_listeners.read().await.get(event_type) { for (_, listener) in listeners { if let Err(error) = listener.call_async::<()>(data.clone()).await { error!("failed to call lua event listener for {event_type}: {error:?}"); diff --git a/src/lua/events.rs b/src/lua/events.rs index 15ab622..bd2a8d0 100644 --- a/src/lua/events.rs +++ b/src/lua/events.rs @@ -1,16 +1,9 @@ use crate::ListenerMap; -use futures::{executor::block_on, lock::Mutex}; +use futures::executor::block_on; use mlua::{Function, Lua, Result, Table}; -use std::{ - sync::Arc, - time::{SystemTime, UNIX_EPOCH}, -}; +use std::time::{SystemTime, UNIX_EPOCH}; -pub fn register_functions( - lua: &Lua, - globals: &Table, - event_listeners: Arc>, -) -> Result<()> { +pub fn register_functions(lua: &Lua, globals: &Table, event_listeners: ListenerMap) -> Result<()> { let m = event_listeners.clone(); globals.set( "add_listener", @@ -18,7 +11,7 @@ pub fn register_functions( move |_, (event_type, callback, id): (String, Function, Option)| { let m = m.clone(); tokio::spawn(async move { - m.lock().await.entry(event_type).or_default().push(( + m.write().await.entry(event_type).or_default().push(( id.unwrap_or(callback.info().name.unwrap_or(format!( "anonymous @ {}", SystemTime::now() @@ -40,7 +33,7 @@ pub fn register_functions( lua.create_function(move |_, (event_type, target_id): (String, String)| { let m = m.clone(); tokio::spawn(async move { - let mut m = m.lock().await; + let mut m = m.write().await; let empty = if let Some(listeners) = m.get_mut(&event_type) { listeners.retain(|(id, _)| target_id != *id); listeners.is_empty() @@ -58,7 +51,7 @@ pub fn register_functions( globals.set( "get_listeners", lua.create_function(move |lua, (): ()| { - let m = block_on(event_listeners.lock()); + let m = block_on(event_listeners.read()); let listeners = lua.create_table()?; for (event_type, callbacks) in m.iter() { diff --git a/src/lua/mod.rs b/src/lua/mod.rs index b4587c0..8355834 100644 --- a/src/lua/mod.rs +++ b/src/lua/mod.rs @@ -9,9 +9,8 @@ pub mod system; pub mod vec3; use crate::ListenerMap; -use futures::lock::Mutex; use mlua::{Lua, Table}; -use std::{io, sync::Arc, time::Duration}; +use std::{io, time::Duration}; #[derive(Debug)] #[allow(dead_code)] @@ -27,7 +26,7 @@ pub enum Error { pub fn register_functions( lua: &Lua, globals: &Table, - event_listeners: Arc>, + event_listeners: ListenerMap, ) -> mlua::Result<()> { globals.set( "sleep", diff --git a/src/main.rs b/src/main.rs index 42ad5b0..c735df7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,6 +17,7 @@ use bevy_log::{ use clap::Parser; use commands::{CommandSource, register}; use futures::lock::Mutex; +use futures_locks::RwLock; use mlua::{Function, Lua}; use std::{ collections::HashMap, @@ -29,12 +30,12 @@ use std::{ const DEFAULT_SCRIPT_PATH: &str = "errornowatcher.lua"; -type ListenerMap = HashMap>; +type ListenerMap = Arc>>>; #[derive(Default, Clone, Component)] pub struct State { lua: Lua, - event_listeners: Arc>, + event_listeners: ListenerMap, commands: Arc>>, http_address: Option, } @@ -44,7 +45,7 @@ async fn main() -> anyhow::Result<()> { let args = arguments::Arguments::parse(); let script_path = args.script.unwrap_or(PathBuf::from(DEFAULT_SCRIPT_PATH)); - let event_listeners = Arc::new(Mutex::new(HashMap::new())); + let event_listeners = Arc::new(RwLock::new(HashMap::new())); let lua = Lua::new(); let globals = lua.globals();