diff --git a/commands/voice/queue.py b/commands/voice/queue.py index 275920b..7a514a1 100644 --- a/commands/voice/queue.py +++ b/commands/voice/queue.py @@ -180,6 +180,8 @@ async def queue_or_play(message, edited=False): message, f"**{len(players[message.guild.id].queue)}.** {queued.format()}", ) + + utils.cooldown(message, 3) elif tokens[0].lower() == "play": await resume(message) else: diff --git a/core.py b/core.py index 6949f55..72f229b 100644 --- a/core.py +++ b/core.py @@ -15,7 +15,7 @@ import commands import utils from commands import Command as C from constants import EMBED_COLOR, OWNERS, PREFIX, RELOADABLE_MODULES -from state import client, command_locks, idle_tracker +from state import client, command_cooldowns, command_locks, idle_tracker async def on_message(message, edited=False): @@ -40,12 +40,22 @@ async def on_message(message, edited=False): f"ambiguous command, could be {' or '.join([f'`{match.value}`' for match in matched])}", ) return + matched = matched[0] - if message.guild.id not in command_locks: - command_locks[message.guild.id] = asyncio.Lock() + if (message.guild.id, message.author.id) not in command_locks: + command_locks[(message.guild.id, message.author.id)] = asyncio.Lock() + await command_locks[(message.guild.id, message.author.id)].acquire() try: - match matched[0]: + if cooldowns := command_cooldowns.get(message.author.id): + if (end_time := cooldowns.get(matched)) and time.time() < end_time: + await utils.reply( + message, + f"please wait **{round(end_time - time.time(), 1)}s** before using this command again!", + ) + return + + match matched: case C.RELOAD if message.author.id in OWNERS: reloaded_modules = set() start = time.time() @@ -115,11 +125,9 @@ async def on_message(message, edited=False): case C.LEAVE: await commands.voice.leave(message) case C.QUEUE | C.PLAY: - async with command_locks[message.guild.id]: - await commands.voice.queue_or_play(message, edited) + await commands.voice.queue_or_play(message, edited) case C.SKIP: - async with command_locks[message.guild.id]: - await commands.voice.skip(message) + await commands.voice.skip(message) case C.RESUME: await commands.voice.resume(message) case C.PAUSE: @@ -142,6 +150,8 @@ async def on_message(message, edited=False): f"exception occurred while processing command: ```\n{"".join(traceback.format_exception(e)).replace("`", "\\`")}```", ) raise e + finally: + command_locks[(message.guild.id, message.author.id)].release() async def on_voice_state_update(_, before, after): diff --git a/state.py b/state.py index cf7feea..adeef32 100644 --- a/state.py +++ b/state.py @@ -25,6 +25,7 @@ intents.message_content = True intents.members = True client = disnake.Client(intents=intents) +command_cooldowns = LimitedSizeDict() command_locks = LimitedSizeDict() idle_tracker = {"is_idle": False, "last_used": time.time()} kill = {"transcript": False} diff --git a/utils.py b/utils.py index 1e179ce..c52055f 100644 --- a/utils.py +++ b/utils.py @@ -1,10 +1,12 @@ import os +import time from logging import error, info, warning import disnake +import commands import constants -from state import message_responses +from state import command_cooldowns, message_responses class ChannelResponseWrapper: @@ -35,6 +37,19 @@ class MessageInteractionWrapper: await self.response.edit_message(content=content, embed=embed, view=view) +def cooldown(message, cooldown_time: int): + possible_commands = commands.match(message.content) + if not possible_commands or len(possible_commands) > 1: + return + command = possible_commands[0] + + end_time = time.time() + cooldown_time + if message.author.id in command_cooldowns: + command_cooldowns[message.author.id][command] = end_time + else: + command_cooldowns[message.author.id] = {command: end_time} + + def format_duration(duration: int, natural: bool = False, short: bool = False): def format_plural(noun, count): if short: