feat(commands): add cooldown system

This commit is contained in:
Ryan 2025-01-08 10:17:06 -05:00
parent 7c2e17e0d3
commit 672ae02e16
Signed by: ErrorNoInternet
GPG Key ID: 2486BFB7B1E6A4A3
4 changed files with 37 additions and 9 deletions

View File

@ -180,6 +180,8 @@ async def queue_or_play(message, edited=False):
message, message,
f"**{len(players[message.guild.id].queue)}.** {queued.format()}", f"**{len(players[message.guild.id].queue)}.** {queued.format()}",
) )
utils.cooldown(message, 3)
elif tokens[0].lower() == "play": elif tokens[0].lower() == "play":
await resume(message) await resume(message)
else: else:

26
core.py
View File

@ -15,7 +15,7 @@ import commands
import utils import utils
from commands import Command as C from commands import Command as C
from constants import EMBED_COLOR, OWNERS, PREFIX, RELOADABLE_MODULES from constants import EMBED_COLOR, OWNERS, PREFIX, RELOADABLE_MODULES
from state import client, command_locks, idle_tracker from state import client, command_cooldowns, command_locks, idle_tracker
async def on_message(message, edited=False): async def on_message(message, edited=False):
@ -40,12 +40,22 @@ async def on_message(message, edited=False):
f"ambiguous command, could be {' or '.join([f'`{match.value}`' for match in matched])}", f"ambiguous command, could be {' or '.join([f'`{match.value}`' for match in matched])}",
) )
return return
matched = matched[0]
if message.guild.id not in command_locks: if (message.guild.id, message.author.id) not in command_locks:
command_locks[message.guild.id] = asyncio.Lock() command_locks[(message.guild.id, message.author.id)] = asyncio.Lock()
await command_locks[(message.guild.id, message.author.id)].acquire()
try: try:
match matched[0]: if cooldowns := command_cooldowns.get(message.author.id):
if (end_time := cooldowns.get(matched)) and time.time() < end_time:
await utils.reply(
message,
f"please wait **{round(end_time - time.time(), 1)}s** before using this command again!",
)
return
match matched:
case C.RELOAD if message.author.id in OWNERS: case C.RELOAD if message.author.id in OWNERS:
reloaded_modules = set() reloaded_modules = set()
start = time.time() start = time.time()
@ -115,11 +125,9 @@ async def on_message(message, edited=False):
case C.LEAVE: case C.LEAVE:
await commands.voice.leave(message) await commands.voice.leave(message)
case C.QUEUE | C.PLAY: case C.QUEUE | C.PLAY:
async with command_locks[message.guild.id]: await commands.voice.queue_or_play(message, edited)
await commands.voice.queue_or_play(message, edited)
case C.SKIP: case C.SKIP:
async with command_locks[message.guild.id]: await commands.voice.skip(message)
await commands.voice.skip(message)
case C.RESUME: case C.RESUME:
await commands.voice.resume(message) await commands.voice.resume(message)
case C.PAUSE: case C.PAUSE:
@ -142,6 +150,8 @@ async def on_message(message, edited=False):
f"exception occurred while processing command: ```\n{"".join(traceback.format_exception(e)).replace("`", "\\`")}```", f"exception occurred while processing command: ```\n{"".join(traceback.format_exception(e)).replace("`", "\\`")}```",
) )
raise e raise e
finally:
command_locks[(message.guild.id, message.author.id)].release()
async def on_voice_state_update(_, before, after): async def on_voice_state_update(_, before, after):

View File

@ -25,6 +25,7 @@ intents.message_content = True
intents.members = True intents.members = True
client = disnake.Client(intents=intents) client = disnake.Client(intents=intents)
command_cooldowns = LimitedSizeDict()
command_locks = LimitedSizeDict() command_locks = LimitedSizeDict()
idle_tracker = {"is_idle": False, "last_used": time.time()} idle_tracker = {"is_idle": False, "last_used": time.time()}
kill = {"transcript": False} kill = {"transcript": False}

View File

@ -1,10 +1,12 @@
import os import os
import time
from logging import error, info, warning from logging import error, info, warning
import disnake import disnake
import commands
import constants import constants
from state import message_responses from state import command_cooldowns, message_responses
class ChannelResponseWrapper: class ChannelResponseWrapper:
@ -35,6 +37,19 @@ class MessageInteractionWrapper:
await self.response.edit_message(content=content, embed=embed, view=view) await self.response.edit_message(content=content, embed=embed, view=view)
def cooldown(message, cooldown_time: int):
possible_commands = commands.match(message.content)
if not possible_commands or len(possible_commands) > 1:
return
command = possible_commands[0]
end_time = time.time() + cooldown_time
if message.author.id in command_cooldowns:
command_cooldowns[message.author.id][command] = end_time
else:
command_cooldowns[message.author.id] = {command: end_time}
def format_duration(duration: int, natural: bool = False, short: bool = False): def format_duration(duration: int, natural: bool = False, short: bool = False):
def format_plural(noun, count): def format_plural(noun, count):
if short: if short: