Compare commits

..

20 Commits

Author SHA1 Message Date
MightyCoderX
dfe05cc548
chore: dockerize bot (#1)
Co-authored-by: ErrorNoInternet <errornointernet@envs.net>
2025-01-06 16:12:56 -05:00
eeca6ec5d9
fix: manually load opus if not loaded already 2025-01-06 16:11:27 -05:00
b9e5f1899e
refactor(commands/voice): throw error if playback fails 2025-01-06 15:55:16 -05:00
cf98497c99
refactor(commands/voice): tweak error message 2025-01-06 15:08:57 -05:00
71be016461
test(format_duration): youtubedl duration formatting 2025-01-06 14:45:37 -05:00
a5503751a5
fix(youtubedl): always cast to int when formatting duration 2025-01-06 14:41:43 -05:00
da5db1e73a
feat(utils): use improved disnake_paginator interaction wrapper
Adds edit support (via utils.reply)
2025-01-06 14:31:08 -05:00
729fc28f1b
test: add format duration tests 2025-01-06 14:06:06 -05:00
439095116f
feat(commands/voice): add ff 2025-01-06 13:43:55 -05:00
f06d8075ea
feat(youtubedl): add fast_forward 2025-01-06 13:00:33 -05:00
7c4041c662
feat(commands/voice/queue): allow removing multiple indices 2025-01-06 12:46:46 -05:00
5333559b25
feat(utils/format_duration): add natural 2025-01-06 12:08:34 -05:00
74629ad984
feat(commands/voice/queue): add immutability checks 2025-01-06 11:59:43 -05:00
b0e378105e
feat(commands/voice): allow immutable commands from non-vc members 2025-01-06 11:53:19 -05:00
290e85a1c1
fix(commands/voice): check if player exists before playing next 2025-01-06 11:53:02 -05:00
42735f9a60
refactor: tweak duration bracket bolding 2025-01-06 11:36:27 -05:00
c0173b87e9
fix: don't edit on channel send 2025-01-06 11:14:49 -05:00
d3fd79e87f
feat(commands/voice/playing): add --description 2025-01-06 10:45:38 -05:00
d9d35a2672
refactor(constants): sort variables 2025-01-06 10:20:14 -05:00
6887ebe087
refactor: tweak some descriptions 2025-01-06 10:12:58 -05:00
9 changed files with 316 additions and 116 deletions

10
Dockerfile Normal file
View File

@ -0,0 +1,10 @@
FROM python:3.13-alpine
RUN apk --no-cache add ffmpeg opus
WORKDIR /bot
COPY . .
RUN pip install -r requirements.txt
CMD ["python", "-OO", "main.py"]

View File

@ -6,6 +6,7 @@ import constants
class Command(enum.Enum):
CLEAR = "clear"
EXECUTE = "execute"
FAST_FORWARD = "ff"
HELP = "help"
JOIN = "join"
LEAVE = "leave"

View File

@ -12,10 +12,6 @@ from state import client, players
async def queue_or_play(message, edited=False):
await ensure_joined(message)
if not command_allowed(message):
return
if message.guild.id not in players:
players[message.guild.id] = youtubedl.QueuedPlayer()
@ -35,11 +31,12 @@ async def queue_or_play(message, edited=False):
"-i",
"--remove-index",
type=int,
help="remove a queued song by index",
nargs="*",
help="remove queued songs by index",
)
parser.add_argument(
"-m",
"--remove-multiple",
"--match-multiple",
action="store_true",
help="continue removing queued after finding a match",
)
@ -73,6 +70,13 @@ async def queue_or_play(message, edited=False):
if not (args := await parser.parse_args(message, tokens)):
return
await ensure_joined(message)
if len(tokens) == 1 and tokens[0].lower() != "play":
if not command_allowed(message, immutable=True):
return
elif not command_allowed(message):
return
if edited:
found = None
for queued in players[message.guild.id].queue:
@ -86,14 +90,25 @@ async def queue_or_play(message, edited=False):
players[message.guild.id].queue.clear()
await utils.add_check_reaction(message)
return
elif i := args.remove_index:
elif indices := args.remove_index:
targets = []
for i in indices:
if i <= 0 or i > len(players[message.guild.id].queue):
await utils.reply(message, "invalid index!")
await utils.reply(message, f"invalid index `{i}`!")
return
targets.append(players[message.guild.id].queue[i - 1])
queued = players[message.guild.id].queue[i - 1]
del players[message.guild.id].queue[i - 1]
await utils.reply(message, f"**X** {queued.format()}")
for target in targets:
if target in players[message.guild.id].queue:
players[message.guild.id].queue.remove(target)
if len(targets) == 1:
await utils.reply(message, f"**X** {targets[0].format()}")
else:
await utils.reply(
message,
f"removed **{len(targets)}** queued {'song' if len(targets) == 1 else 'songs'}",
)
elif args.remove_title or args.remove_queuer:
targets = []
for queued in players[message.guild.id].queue:
@ -104,7 +119,7 @@ async def queue_or_play(message, edited=False):
if q := args.remove_queuer:
if q == queued.trigger_message.author.id:
targets.append(queued)
if not args.remove_multiple:
if not args.match_multiple:
targets = targets[:1]
for target in targets:
@ -162,8 +177,7 @@ async def queue_or_play(message, edited=False):
message,
f"**{len(players[message.guild.id].queue)}.** {queued.format()}",
)
else:
if tokens[0].lower() == "play":
elif tokens[0].lower() == "play":
await resume(message)
else:
if players[message.guild.id].queue:
@ -173,7 +187,8 @@ async def queue_or_play(message, edited=False):
queued.player.duration if queued.player.duration else 0
for queued in players[message.guild.id].queue
]
)
),
natural=True,
)
def embed(description):
@ -182,7 +197,7 @@ async def queue_or_play(message, edited=False):
color=constants.EMBED_COLOR,
)
if formatted_duration:
e.set_footer(text=f"{formatted_duration} long")
e.set_footer(text=f"{formatted_duration} in total")
return e
await disnake_paginator.ButtonPaginator(
@ -204,7 +219,7 @@ async def queue_or_play(message, edited=False):
],
)
),
).start(disnake_paginator.wrappers.MessageInteractionWrapper(message))
).start(utils.MessageInteractionWrapper(message))
else:
await utils.reply(
message,
@ -213,20 +228,51 @@ async def queue_or_play(message, edited=False):
async def playing(message):
if not command_allowed(message):
if not command_allowed(message, immutable=True):
return
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(
tokens[0], "get information about the currently playing song"
)
parser.add_argument(
"-d",
"--description",
action="store_true",
help="get the description",
)
if not (args := await parser.parse_args(message, tokens)):
return
if source := message.guild.voice_client.source:
if args.description:
if description := source.description:
paginator = disnake_paginator.ButtonPaginator(
invalid_user_function=utils.invalid_user_handler,
color=constants.EMBED_COLOR,
title=source.title,
segments=disnake_paginator.split(description),
)
for embed in paginator.embeds:
embed.url = source.original_url
await paginator.start(utils.MessageInteractionWrapper(message))
else:
await utils.reply(
message,
source.description or "no description found!",
)
return
bar_length = 35
progress = source.original.progress / source.duration
embed = disnake.Embed(
color=constants.EMBED_COLOR,
title=source.title,
url=source.original_url,
description=f"{'⏸️ ' if message.guild.voice_client.is_paused() else ''}"
f"`[{'#'*int(progress * bar_length)}{'-'*int((1 - progress) * bar_length)}]` "
f"{youtubedl.format_duration(int(source.original.progress))} / {youtubedl.format_duration(source.duration)} ({round(progress * 100)}%)",
url=source.original_url,
f"**{youtubedl.format_duration(int(source.original.progress))}** / **{youtubedl.format_duration(source.duration)}** (**{round(progress * 100)}%**)",
)
embed.add_field(name="Volume", value=f"{int(source.volume*100)}%")
embed.add_field(name="Views", value=f"{source.view_count:,}")
@ -247,6 +293,31 @@ async def playing(message):
)
async def fast_forward(message):
if not command_allowed(message):
return
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(tokens[0], "fast forward audio playback")
parser.add_argument(
"seconds",
type=lambda v: arguments.range_type(v, min=0, max=300),
help="the amount of seconds to fast forward",
)
if not (args := await parser.parse_args(message, tokens)):
return
if not message.guild.voice_client.source:
await utils.reply(message, "nothing is playing!")
return
message.guild.voice_client.pause()
message.guild.voice_client.source.original.fast_forward(args.seconds)
message.guild.voice_client.resume()
await utils.add_check_reaction(message)
async def skip(message):
if not command_allowed(message):
return
@ -309,11 +380,11 @@ async def pause(message):
async def volume(message):
if not command_allowed(message):
if not command_allowed(message, immutable=True):
return
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(tokens[0], "set the current volume level")
parser = arguments.ArgumentParser(tokens[0], "get or set the current volume level")
parser.add_argument(
"volume",
nargs="?",
@ -333,6 +404,9 @@ async def volume(message):
f"{int(message.guild.voice_client.source.volume * 100)}",
)
else:
if not command_allowed(message):
return
message.guild.voice_client.source.volume = float(args.volume) / 100.0
await utils.add_check_reaction(message)
@ -362,17 +436,17 @@ def play_after_callback(e, message, once):
def play_next(message, once=False, first=False):
message.guild.voice_client.stop()
if players[message.guild.id].queue:
if message.guild.id in players and players[message.guild.id].queue:
queued = players[message.guild.id].queue_pop()
try:
message.guild.voice_client.play(
queued.player, after=lambda e: play_after_callback(e, message, once)
)
except Exception as e:
client.loop.create_task(
utils.channel_send(message, f"error while trying to play: `{e}`")
except disnake.opus.OpusNotLoaded:
utils.load_opus()
message.guild.voice_client.play(
queued.player, after=lambda e: play_after_callback(e, message, once)
)
return
client.loop.create_task(
utils.channel_send(message, queued.format(show_queuer=not first))
)
@ -383,10 +457,15 @@ async def ensure_joined(message):
if message.author.voice:
await message.author.voice.channel.connect()
else:
await utils.reply(message, "You are not connected to a voice channel.")
await utils.reply(message, "you are not connected to a voice channel!")
def command_allowed(message):
if not message.author.voice or not message.guild.voice_client:
def command_allowed(message, immutable=False):
if not message.guild.voice_client:
return
if immutable:
return message.channel.id == message.guild.voice_client.channel.id
else:
if not message.author.voice:
return False
return message.author.voice.channel.id == message.guild.voice_client.channel.id

View File

@ -1,5 +1,20 @@
import os
YTDL_OPTIONS = {
"color": "never",
"default_search": "auto",
"format": "bestaudio/best",
"ignoreerrors": False,
"logtostderr": False,
"no_warnings": True,
"noplaylist": True,
"outtmpl": "%(extractor)s-%(id)s-%(title)s.%(ext)s",
"quiet": True,
"restrictfilenames": True,
"socket_timeout": 15,
"source_address": "0.0.0.0",
}
EMBED_COLOR = 0xFF6600
OWNERS = [531392146767347712]
PREFIX = "%"
@ -19,22 +34,6 @@ RELOADABLE_MODULES = [
"youtubedl",
]
YTDL_OPTIONS = {
"color": "never",
"default_search": "auto",
"format": "bestaudio/best",
"ignoreerrors": False,
"logtostderr": False,
"no_warnings": True,
"noplaylist": True,
"outtmpl": "%(extractor)s-%(id)s-%(title)s.%(ext)s",
"quiet": True,
"restrictfilenames": True,
"socket_timeout": 15,
"source_address": "0.0.0.0",
}
SECRETS = {
"TOKEN": os.getenv("BOT_TOKEN"),
}

View File

@ -47,6 +47,7 @@ async def on_message(message, edited=False):
match matched[0]:
case C.RELOAD if message.author.id in constants.OWNERS:
reloaded_modules = set()
rreload(reloaded_modules, __import__("core"))
for module in filter(
lambda v: inspect.ismodule(v)
and v.__name__ in constants.RELOADABLE_MODULES,
@ -92,13 +93,11 @@ async def on_message(message, edited=False):
invalid_user_function=utils.invalid_user_handler,
color=constants.EMBED_COLOR,
segments=disnake_paginator.split(output),
).start(
disnake_paginator.wrappers.MessageInteractionWrapper(message)
)
).start(utils.MessageInteractionWrapper(message))
elif len(output.strip()) == 0:
await utils.add_check_reaction(message)
else:
await utils.channel_send(message, output)
await utils.reply(message, output)
case C.CLEAR | C.PURGE if message.author.id in constants.OWNERS:
await commands.tools.clear(message)
case C.JOIN:
@ -123,6 +122,8 @@ async def on_message(message, edited=False):
await commands.bot.uptime(message)
case C.PLAYING:
await commands.voice.playing(message)
case C.FAST_FORWARD:
await commands.voice.fast_forward(message)
except Exception as e:
await utils.reply(
message,

3
tests/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from . import test_format_duration
__all__ = ["test_format_duration"]

View File

@ -0,0 +1,61 @@
import unittest
import utils
import youtubedl
class TestFormatDuration(unittest.TestCase):
def test_youtubedl(self):
self.assertEqual(youtubedl.format_duration(0), "00:00")
self.assertEqual(youtubedl.format_duration(0.5), "00:00")
self.assertEqual(youtubedl.format_duration(60.5), "01:00")
self.assertEqual(youtubedl.format_duration(1), "00:01")
self.assertEqual(youtubedl.format_duration(60), "01:00")
self.assertEqual(youtubedl.format_duration(60 + 30), "01:30")
self.assertEqual(youtubedl.format_duration(60 * 60), "01:00:00")
self.assertEqual(youtubedl.format_duration(60 * 60 + 30), "01:00:30")
def test_utils(self):
self.assertEqual(utils.format_duration(0), "")
self.assertEqual(utils.format_duration(60 * 60 * 24 * 7), "1 week")
self.assertEqual(utils.format_duration(60 * 60 * 24 * 21), "3 weeks")
self.assertEqual(
utils.format_duration((60 * 60 * 24 * 21) - 1),
"2 weeks, 6 days, 23 hours, 59 minutes, 59 seconds",
)
self.assertEqual(utils.format_duration(60), "1 minute")
self.assertEqual(utils.format_duration(60 * 2), "2 minutes")
self.assertEqual(utils.format_duration(60 * 59), "59 minutes")
self.assertEqual(utils.format_duration(60 * 60), "1 hour")
self.assertEqual(utils.format_duration(60 * 60 * 2), "2 hours")
self.assertEqual(utils.format_duration(1), "1 second")
self.assertEqual(utils.format_duration(60 + 5), "1 minute, 5 seconds")
self.assertEqual(utils.format_duration(60 * 60 + 30), "1 hour, 30 seconds")
self.assertEqual(
utils.format_duration(60 * 60 + 60 + 30), "1 hour, 1 minute, 30 seconds"
)
self.assertEqual(
utils.format_duration(60 * 60 * 24 * 7 + 30), "1 week, 30 seconds"
)
def test_utils_natural(self):
def format(seconds: int):
return utils.format_duration(seconds, natural=True)
self.assertEqual(format(0), "")
self.assertEqual(format(60 * 60 * 24 * 7), "1 week")
self.assertEqual(format(60 * 60 * 24 * 21), "3 weeks")
self.assertEqual(
format((60 * 60 * 24 * 21) - 1),
"2 weeks, 6 days, 23 hours, 59 minutes and 59 seconds",
)
self.assertEqual(format(60), "1 minute")
self.assertEqual(format(60 * 2), "2 minutes")
self.assertEqual(format(60 * 59), "59 minutes")
self.assertEqual(format(60 * 60), "1 hour")
self.assertEqual(format(60 * 60 * 2), "2 hours")
self.assertEqual(format(1), "1 second")
self.assertEqual(format(60 + 5), "1 minute and 5 seconds")
self.assertEqual(format(60 * 60 + 30), "1 hour and 30 seconds")
self.assertEqual(format(60 * 60 + 60 + 30), "1 hour, 1 minute and 30 seconds")
self.assertEqual(format(60 * 60 * 24 * 7 + 30), "1 week and 30 seconds")

View File

@ -1,10 +1,40 @@
import os
import disnake
import constants
from state import message_responses
def format_duration(duration: int):
class ChannelResponseWrapper:
def __init__(self, message):
self.message = message
self.sent_message = None
async def send_message(self, **kwargs):
if "ephemeral" in kwargs:
del kwargs["ephemeral"]
self.sent_message = await reply(self.message, **kwargs)
async def edit_message(self, content=None, embed=None, view=None):
if self.sent_message:
content = content or self.sent_message.content
if not embed and len(self.sent_message.embeds) > 0:
embed = self.sent_message.embeds[0]
await self.sent_message.edit(content=content, embed=embed, view=view)
class MessageInteractionWrapper:
def __init__(self, message):
self.message = message
self.author = message.author
self.response = ChannelResponseWrapper(message)
async def edit_original_message(self, content=None, embed=None, view=None):
await self.response.edit_message(content=content, embed=embed, view=view)
def format_duration(duration: int, natural: bool = False):
def format_plural(noun, count):
return noun if count == 1 else noun + "s"
@ -29,8 +59,11 @@ def format_duration(duration: int):
if duration > 0:
segments.append(f"{duration} {format_plural('second', duration)}")
if not natural or len(segments) <= 1:
return ", ".join(segments)
return ", ".join(segments[:-1]) + f" and {segments[-1]}"
async def add_check_reaction(message):
await message.add_reaction("")
@ -46,23 +79,18 @@ async def reply(message, *args, **kwargs):
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
)
message_responses[message.id] = response
return message_responses[message.id]
async def channel_send(message, *args, **kwargs):
if message.id in message_responses:
await message_responses[message.id].edit(
await message.channel.send(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
)
else:
response = await message.channel.send(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
)
message_responses[message.id] = response
async def invalid_user_handler(interaction):
await interaction.response.send_message(
"You are not the intended receiver of this message!", ephemeral=True
"you are not the intended receiver of this message!", ephemeral=True
)
@ -72,3 +100,16 @@ def filter_secrets(text: str) -> str:
continue
text = text.replace(secret, f"<{secret_name}>")
return text
def load_opus():
print("opus wasn't automatically loaded! trying to load manually...")
for path in ["/usr/lib64/libopus.so.0", "/usr/lib/libopus.so.0"]:
if os.path.exists(path):
try:
disnake.opus.load_opus(path)
print(f"successfully loaded opus from {path}")
return
except Exception as e:
print(f"failed to load opus from {path}: {e}")
raise Exception("could not locate working opus library")

View File

@ -11,28 +11,33 @@ import constants
ytdl = yt_dlp.YoutubeDL(constants.YTDL_OPTIONS)
class TrackedAudioSource(disnake.AudioSource):
class CustomAudioSource(disnake.AudioSource):
def __init__(self, source):
self._source = source
self.count = 0
self.read_count = 0
def read(self) -> bytes:
data = self._source.read()
if data:
self.count += 1
self.read_count += 1
return data
def fast_forward(self, seconds: int):
for _ in range(int(seconds / 0.02)):
self.read()
@property
def progress(self) -> float:
return self.count * 0.02
return self.read_count * 0.02
class YTDLSource(disnake.PCMVolumeTransformer):
def __init__(
self, source: TrackedAudioSource, *, data: dict[str, Any], volume: float = 0.5
self, source: CustomAudioSource, *, data: dict[str, Any], volume: float = 0.5
):
super().__init__(source, volume)
self.description = data.get("description")
self.duration = data.get("duration")
self.original_url = data.get("original_url")
self.thumbnail_url = data.get("thumbnail")
@ -56,7 +61,7 @@ class YTDLSource(disnake.PCMVolumeTransformer):
data = data["entries"][0]
return cls(
TrackedAudioSource(
CustomAudioSource(
disnake.FFmpegPCMAudio(
data["url"] if stream else ytdl.prepare_filename(data),
before_options="-vn -reconnect 1",
@ -82,14 +87,14 @@ class QueuedSong:
return (
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''})\n**duration:** {format_duration(self.player.duration) if self.player.duration else '[live]'}"
+ (
f", **queuer:** <@{self.trigger_message.author.id}>"
f", **queued by:** <@{self.trigger_message.author.id}>"
if show_queuer
else ""
)
)
else:
return (
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''}) **[{format_duration(self.player.duration) if self.player.duration else 'live'}]**"
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''}) [**{format_duration(self.player.duration) if self.player.duration else 'live'}**]"
+ (f" (<@{self.trigger_message.author.id}>)" if show_queuer else "")
)
@ -117,8 +122,8 @@ class QueuedPlayer:
return self.__repr__()
def format_duration(duration: int) -> str:
hours, duration = divmod(duration, 3600)
def format_duration(duration: int | float) -> str:
hours, duration = divmod(int(duration), 3600)
minutes, duration = divmod(duration, 60)
segments = [hours, minutes, duration]
if len(segments) == 3 and segments[0] == 0: