refactor: use a RwLock for event listeners

This commit is contained in:
2025-02-24 16:52:34 -05:00
parent 36054ced03
commit 247612fad0
6 changed files with 26 additions and 20 deletions

View File

@@ -123,7 +123,7 @@ pub async fn handle_event(client: Client, event: Event, state: State) -> anyhow:
}
async fn call_listeners<T: Clone + IntoLuaMulti>(state: &State, event_type: &str, data: T) {
if let Some(listeners) = state.event_listeners.lock().await.get(event_type) {
if let Some(listeners) = state.event_listeners.read().await.get(event_type) {
for (_, listener) in listeners {
if let Err(error) = listener.call_async::<()>(data.clone()).await {
error!("failed to call lua event listener for {event_type}: {error:?}");

View File

@@ -1,16 +1,9 @@
use crate::ListenerMap;
use futures::{executor::block_on, lock::Mutex};
use futures::executor::block_on;
use mlua::{Function, Lua, Result, Table};
use std::{
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use std::time::{SystemTime, UNIX_EPOCH};
pub fn register_functions(
lua: &Lua,
globals: &Table,
event_listeners: Arc<Mutex<ListenerMap>>,
) -> Result<()> {
pub fn register_functions(lua: &Lua, globals: &Table, event_listeners: ListenerMap) -> Result<()> {
let m = event_listeners.clone();
globals.set(
"add_listener",
@@ -18,7 +11,7 @@ pub fn register_functions(
move |_, (event_type, callback, id): (String, Function, Option<String>)| {
let m = m.clone();
tokio::spawn(async move {
m.lock().await.entry(event_type).or_default().push((
m.write().await.entry(event_type).or_default().push((
id.unwrap_or(callback.info().name.unwrap_or(format!(
"anonymous @ {}",
SystemTime::now()
@@ -40,7 +33,7 @@ pub fn register_functions(
lua.create_function(move |_, (event_type, target_id): (String, String)| {
let m = m.clone();
tokio::spawn(async move {
let mut m = m.lock().await;
let mut m = m.write().await;
let empty = if let Some(listeners) = m.get_mut(&event_type) {
listeners.retain(|(id, _)| target_id != *id);
listeners.is_empty()
@@ -58,7 +51,7 @@ pub fn register_functions(
globals.set(
"get_listeners",
lua.create_function(move |lua, (): ()| {
let m = block_on(event_listeners.lock());
let m = block_on(event_listeners.read());
let listeners = lua.create_table()?;
for (event_type, callbacks) in m.iter() {

View File

@@ -9,9 +9,8 @@ pub mod system;
pub mod vec3;
use crate::ListenerMap;
use futures::lock::Mutex;
use mlua::{Lua, Table};
use std::{io, sync::Arc, time::Duration};
use std::{io, time::Duration};
#[derive(Debug)]
#[allow(dead_code)]
@@ -27,7 +26,7 @@ pub enum Error {
pub fn register_functions(
lua: &Lua,
globals: &Table,
event_listeners: Arc<Mutex<ListenerMap>>,
event_listeners: ListenerMap,
) -> mlua::Result<()> {
globals.set(
"sleep",

View File

@@ -17,6 +17,7 @@ use bevy_log::{
use clap::Parser;
use commands::{CommandSource, register};
use futures::lock::Mutex;
use futures_locks::RwLock;
use mlua::{Function, Lua};
use std::{
collections::HashMap,
@@ -29,12 +30,12 @@ use std::{
const DEFAULT_SCRIPT_PATH: &str = "errornowatcher.lua";
type ListenerMap = HashMap<String, Vec<(String, Function)>>;
type ListenerMap = Arc<RwLock<HashMap<String, Vec<(String, Function)>>>>;
#[derive(Default, Clone, Component)]
pub struct State {
lua: Lua,
event_listeners: Arc<Mutex<ListenerMap>>,
event_listeners: ListenerMap,
commands: Arc<CommandDispatcher<Mutex<CommandSource>>>,
http_address: Option<SocketAddr>,
}
@@ -44,7 +45,7 @@ async fn main() -> anyhow::Result<()> {
let args = arguments::Arguments::parse();
let script_path = args.script.unwrap_or(PathBuf::from(DEFAULT_SCRIPT_PATH));
let event_listeners = Arc::new(Mutex::new(HashMap::new()));
let event_listeners = Arc::new(RwLock::new(HashMap::new()));
let lua = Lua::new();
let globals = lua.globals();