Compare commits

..

No commits in common. "dfe05cc548682ad8f006557fe53dc93048db8480" and "5216d611c3a04a99b271e9389f162c99cdf1ae0b" have entirely different histories.

9 changed files with 115 additions and 315 deletions

View File

@ -1,10 +0,0 @@
FROM python:3.13-alpine
RUN apk --no-cache add ffmpeg opus
WORKDIR /bot
COPY . .
RUN pip install -r requirements.txt
CMD ["python", "-OO", "main.py"]

View File

@ -6,7 +6,6 @@ import constants
class Command(enum.Enum):
CLEAR = "clear"
EXECUTE = "execute"
FAST_FORWARD = "ff"
HELP = "help"
JOIN = "join"
LEAVE = "leave"

View File

@ -12,6 +12,10 @@ from state import client, players
async def queue_or_play(message, edited=False):
await ensure_joined(message)
if not command_allowed(message):
return
if message.guild.id not in players:
players[message.guild.id] = youtubedl.QueuedPlayer()
@ -31,12 +35,11 @@ async def queue_or_play(message, edited=False):
"-i",
"--remove-index",
type=int,
nargs="*",
help="remove queued songs by index",
help="remove a queued song by index",
)
parser.add_argument(
"-m",
"--match-multiple",
"--remove-multiple",
action="store_true",
help="continue removing queued after finding a match",
)
@ -70,13 +73,6 @@ async def queue_or_play(message, edited=False):
if not (args := await parser.parse_args(message, tokens)):
return
await ensure_joined(message)
if len(tokens) == 1 and tokens[0].lower() != "play":
if not command_allowed(message, immutable=True):
return
elif not command_allowed(message):
return
if edited:
found = None
for queued in players[message.guild.id].queue:
@ -90,25 +86,14 @@ async def queue_or_play(message, edited=False):
players[message.guild.id].queue.clear()
await utils.add_check_reaction(message)
return
elif indices := args.remove_index:
targets = []
for i in indices:
if i <= 0 or i > len(players[message.guild.id].queue):
await utils.reply(message, f"invalid index `{i}`!")
return
targets.append(players[message.guild.id].queue[i - 1])
elif i := args.remove_index:
if i <= 0 or i > len(players[message.guild.id].queue):
await utils.reply(message, "invalid index!")
return
for target in targets:
if target in players[message.guild.id].queue:
players[message.guild.id].queue.remove(target)
if len(targets) == 1:
await utils.reply(message, f"**X** {targets[0].format()}")
else:
await utils.reply(
message,
f"removed **{len(targets)}** queued {'song' if len(targets) == 1 else 'songs'}",
)
queued = players[message.guild.id].queue[i - 1]
del players[message.guild.id].queue[i - 1]
await utils.reply(message, f"**X** {queued.format()}")
elif args.remove_title or args.remove_queuer:
targets = []
for queued in players[message.guild.id].queue:
@ -119,7 +104,7 @@ async def queue_or_play(message, edited=False):
if q := args.remove_queuer:
if q == queued.trigger_message.author.id:
targets.append(queued)
if not args.match_multiple:
if not args.remove_multiple:
targets = targets[:1]
for target in targets:
@ -177,102 +162,71 @@ async def queue_or_play(message, edited=False):
message,
f"**{len(players[message.guild.id].queue)}.** {queued.format()}",
)
elif tokens[0].lower() == "play":
await resume(message)
else:
if players[message.guild.id].queue:
formatted_duration = utils.format_duration(
sum(
[
queued.player.duration if queued.player.duration else 0
for queued in players[message.guild.id].queue
]
),
natural=True,
)
def embed(description):
e = disnake.Embed(
description=description,
color=constants.EMBED_COLOR,
)
if formatted_duration:
e.set_footer(text=f"{formatted_duration} in total")
return e
await disnake_paginator.ButtonPaginator(
invalid_user_function=utils.invalid_user_handler,
color=constants.EMBED_COLOR,
segments=list(
map(
embed,
[
"\n\n".join(
[
f"**{i + 1}.** {queued.format(show_queuer=True, hide_preview=True, multiline=True)}"
for i, queued in batch
]
)
for batch in itertools.batched(
enumerate(players[message.guild.id].queue), 10
)
],
)
),
).start(utils.MessageInteractionWrapper(message))
if tokens[0].lower() == "play":
await resume(message)
else:
await utils.reply(
message,
"nothing is queued!",
)
if players[message.guild.id].queue:
formatted_duration = utils.format_duration(
sum(
[
queued.player.duration if queued.player.duration else 0
for queued in players[message.guild.id].queue
]
)
)
def embed(description):
e = disnake.Embed(
description=description,
color=constants.EMBED_COLOR,
)
if formatted_duration:
e.set_footer(text=f"{formatted_duration} long")
return e
async def playing(message):
if not command_allowed(message, immutable=True):
return
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(
tokens[0], "get information about the currently playing song"
)
parser.add_argument(
"-d",
"--description",
action="store_true",
help="get the description",
)
if not (args := await parser.parse_args(message, tokens)):
return
if source := message.guild.voice_client.source:
if args.description:
if description := source.description:
paginator = disnake_paginator.ButtonPaginator(
await disnake_paginator.ButtonPaginator(
invalid_user_function=utils.invalid_user_handler,
color=constants.EMBED_COLOR,
title=source.title,
segments=disnake_paginator.split(description),
)
for embed in paginator.embeds:
embed.url = source.original_url
await paginator.start(utils.MessageInteractionWrapper(message))
segments=list(
map(
embed,
[
"\n\n".join(
[
f"**{i + 1}.** {queued.format(show_queuer=True, hide_preview=True, multiline=True)}"
for i, queued in batch
]
)
for batch in itertools.batched(
enumerate(players[message.guild.id].queue), 10
)
],
)
),
).start(disnake_paginator.wrappers.MessageInteractionWrapper(message))
else:
await utils.reply(
message,
source.description or "no description found!",
"nothing is queued!",
)
return
async def playing(message):
if not command_allowed(message):
return
if source := message.guild.voice_client.source:
bar_length = 35
progress = source.original.progress / source.duration
embed = disnake.Embed(
color=constants.EMBED_COLOR,
title=source.title,
url=source.original_url,
description=f"{'⏸️ ' if message.guild.voice_client.is_paused() else ''}"
f"`[{'#'*int(progress * bar_length)}{'-'*int((1 - progress) * bar_length)}]` "
f"**{youtubedl.format_duration(int(source.original.progress))}** / **{youtubedl.format_duration(source.duration)}** (**{round(progress * 100)}%**)",
f"`[{'#'*int(progress * bar_length)}{'-'*int((1 - progress) * bar_length)}]`"
f"{youtubedl.format_duration(int(source.original.progress))} / {youtubedl.format_duration(source.duration)} ({round(progress * 100)}%)",
url=source.original_url,
)
embed.add_field(name="Volume", value=f"{int(source.volume*100)}%")
embed.add_field(name="Views", value=f"{source.view_count:,}")
@ -293,31 +247,6 @@ async def playing(message):
)
async def fast_forward(message):
if not command_allowed(message):
return
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(tokens[0], "fast forward audio playback")
parser.add_argument(
"seconds",
type=lambda v: arguments.range_type(v, min=0, max=300),
help="the amount of seconds to fast forward",
)
if not (args := await parser.parse_args(message, tokens)):
return
if not message.guild.voice_client.source:
await utils.reply(message, "nothing is playing!")
return
message.guild.voice_client.pause()
message.guild.voice_client.source.original.fast_forward(args.seconds)
message.guild.voice_client.resume()
await utils.add_check_reaction(message)
async def skip(message):
if not command_allowed(message):
return
@ -380,11 +309,11 @@ async def pause(message):
async def volume(message):
if not command_allowed(message, immutable=True):
if not command_allowed(message):
return
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(tokens[0], "get or set the current volume level")
parser = arguments.ArgumentParser(tokens[0], "set the current volume level")
parser.add_argument(
"volume",
nargs="?",
@ -404,9 +333,6 @@ async def volume(message):
f"{int(message.guild.voice_client.source.volume * 100)}",
)
else:
if not command_allowed(message):
return
message.guild.voice_client.source.volume = float(args.volume) / 100.0
await utils.add_check_reaction(message)
@ -436,17 +362,17 @@ def play_after_callback(e, message, once):
def play_next(message, once=False, first=False):
message.guild.voice_client.stop()
if message.guild.id in players and players[message.guild.id].queue:
if players[message.guild.id].queue:
queued = players[message.guild.id].queue_pop()
try:
message.guild.voice_client.play(
queued.player, after=lambda e: play_after_callback(e, message, once)
)
except disnake.opus.OpusNotLoaded:
utils.load_opus()
message.guild.voice_client.play(
queued.player, after=lambda e: play_after_callback(e, message, once)
except Exception as e:
client.loop.create_task(
utils.channel_send(message, f"error while trying to play: `{e}`")
)
return
client.loop.create_task(
utils.channel_send(message, queued.format(show_queuer=not first))
)
@ -457,15 +383,10 @@ async def ensure_joined(message):
if message.author.voice:
await message.author.voice.channel.connect()
else:
await utils.reply(message, "you are not connected to a voice channel!")
await utils.reply(message, "You are not connected to a voice channel.")
def command_allowed(message, immutable=False):
if not message.guild.voice_client:
return
if immutable:
return message.channel.id == message.guild.voice_client.channel.id
else:
if not message.author.voice:
return False
return message.author.voice.channel.id == message.guild.voice_client.channel.id
def command_allowed(message):
if not message.author.voice or not message.guild.voice_client:
return False
return message.author.voice.channel.id == message.guild.voice_client.channel.id

View File

@ -1,20 +1,5 @@
import os
YTDL_OPTIONS = {
"color": "never",
"default_search": "auto",
"format": "bestaudio/best",
"ignoreerrors": False,
"logtostderr": False,
"no_warnings": True,
"noplaylist": True,
"outtmpl": "%(extractor)s-%(id)s-%(title)s.%(ext)s",
"quiet": True,
"restrictfilenames": True,
"socket_timeout": 15,
"source_address": "0.0.0.0",
}
EMBED_COLOR = 0xFF6600
OWNERS = [531392146767347712]
PREFIX = "%"
@ -34,6 +19,22 @@ RELOADABLE_MODULES = [
"youtubedl",
]
YTDL_OPTIONS = {
"color": "never",
"default_search": "auto",
"format": "bestaudio/best",
"ignoreerrors": False,
"logtostderr": False,
"no_warnings": True,
"noplaylist": True,
"outtmpl": "%(extractor)s-%(id)s-%(title)s.%(ext)s",
"quiet": True,
"restrictfilenames": True,
"socket_timeout": 15,
"source_address": "0.0.0.0",
}
SECRETS = {
"TOKEN": os.getenv("BOT_TOKEN"),
}

View File

@ -47,7 +47,6 @@ async def on_message(message, edited=False):
match matched[0]:
case C.RELOAD if message.author.id in constants.OWNERS:
reloaded_modules = set()
rreload(reloaded_modules, __import__("core"))
for module in filter(
lambda v: inspect.ismodule(v)
and v.__name__ in constants.RELOADABLE_MODULES,
@ -93,11 +92,13 @@ async def on_message(message, edited=False):
invalid_user_function=utils.invalid_user_handler,
color=constants.EMBED_COLOR,
segments=disnake_paginator.split(output),
).start(utils.MessageInteractionWrapper(message))
).start(
disnake_paginator.wrappers.MessageInteractionWrapper(message)
)
elif len(output.strip()) == 0:
await utils.add_check_reaction(message)
else:
await utils.reply(message, output)
await utils.channel_send(message, output)
case C.CLEAR | C.PURGE if message.author.id in constants.OWNERS:
await commands.tools.clear(message)
case C.JOIN:
@ -122,8 +123,6 @@ async def on_message(message, edited=False):
await commands.bot.uptime(message)
case C.PLAYING:
await commands.voice.playing(message)
case C.FAST_FORWARD:
await commands.voice.fast_forward(message)
except Exception as e:
await utils.reply(
message,

View File

@ -1,3 +0,0 @@
from . import test_format_duration
__all__ = ["test_format_duration"]

View File

@ -1,61 +0,0 @@
import unittest
import utils
import youtubedl
class TestFormatDuration(unittest.TestCase):
def test_youtubedl(self):
self.assertEqual(youtubedl.format_duration(0), "00:00")
self.assertEqual(youtubedl.format_duration(0.5), "00:00")
self.assertEqual(youtubedl.format_duration(60.5), "01:00")
self.assertEqual(youtubedl.format_duration(1), "00:01")
self.assertEqual(youtubedl.format_duration(60), "01:00")
self.assertEqual(youtubedl.format_duration(60 + 30), "01:30")
self.assertEqual(youtubedl.format_duration(60 * 60), "01:00:00")
self.assertEqual(youtubedl.format_duration(60 * 60 + 30), "01:00:30")
def test_utils(self):
self.assertEqual(utils.format_duration(0), "")
self.assertEqual(utils.format_duration(60 * 60 * 24 * 7), "1 week")
self.assertEqual(utils.format_duration(60 * 60 * 24 * 21), "3 weeks")
self.assertEqual(
utils.format_duration((60 * 60 * 24 * 21) - 1),
"2 weeks, 6 days, 23 hours, 59 minutes, 59 seconds",
)
self.assertEqual(utils.format_duration(60), "1 minute")
self.assertEqual(utils.format_duration(60 * 2), "2 minutes")
self.assertEqual(utils.format_duration(60 * 59), "59 minutes")
self.assertEqual(utils.format_duration(60 * 60), "1 hour")
self.assertEqual(utils.format_duration(60 * 60 * 2), "2 hours")
self.assertEqual(utils.format_duration(1), "1 second")
self.assertEqual(utils.format_duration(60 + 5), "1 minute, 5 seconds")
self.assertEqual(utils.format_duration(60 * 60 + 30), "1 hour, 30 seconds")
self.assertEqual(
utils.format_duration(60 * 60 + 60 + 30), "1 hour, 1 minute, 30 seconds"
)
self.assertEqual(
utils.format_duration(60 * 60 * 24 * 7 + 30), "1 week, 30 seconds"
)
def test_utils_natural(self):
def format(seconds: int):
return utils.format_duration(seconds, natural=True)
self.assertEqual(format(0), "")
self.assertEqual(format(60 * 60 * 24 * 7), "1 week")
self.assertEqual(format(60 * 60 * 24 * 21), "3 weeks")
self.assertEqual(
format((60 * 60 * 24 * 21) - 1),
"2 weeks, 6 days, 23 hours, 59 minutes and 59 seconds",
)
self.assertEqual(format(60), "1 minute")
self.assertEqual(format(60 * 2), "2 minutes")
self.assertEqual(format(60 * 59), "59 minutes")
self.assertEqual(format(60 * 60), "1 hour")
self.assertEqual(format(60 * 60 * 2), "2 hours")
self.assertEqual(format(1), "1 second")
self.assertEqual(format(60 + 5), "1 minute and 5 seconds")
self.assertEqual(format(60 * 60 + 30), "1 hour and 30 seconds")
self.assertEqual(format(60 * 60 + 60 + 30), "1 hour, 1 minute and 30 seconds")
self.assertEqual(format(60 * 60 * 24 * 7 + 30), "1 week and 30 seconds")

View File

@ -1,40 +1,10 @@
import os
import disnake
import constants
from state import message_responses
class ChannelResponseWrapper:
def __init__(self, message):
self.message = message
self.sent_message = None
async def send_message(self, **kwargs):
if "ephemeral" in kwargs:
del kwargs["ephemeral"]
self.sent_message = await reply(self.message, **kwargs)
async def edit_message(self, content=None, embed=None, view=None):
if self.sent_message:
content = content or self.sent_message.content
if not embed and len(self.sent_message.embeds) > 0:
embed = self.sent_message.embeds[0]
await self.sent_message.edit(content=content, embed=embed, view=view)
class MessageInteractionWrapper:
def __init__(self, message):
self.message = message
self.author = message.author
self.response = ChannelResponseWrapper(message)
async def edit_original_message(self, content=None, embed=None, view=None):
await self.response.edit_message(content=content, embed=embed, view=view)
def format_duration(duration: int, natural: bool = False):
def format_duration(duration: int):
def format_plural(noun, count):
return noun if count == 1 else noun + "s"
@ -59,10 +29,7 @@ def format_duration(duration: int, natural: bool = False):
if duration > 0:
segments.append(f"{duration} {format_plural('second', duration)}")
if not natural or len(segments) <= 1:
return ", ".join(segments)
return ", ".join(segments[:-1]) + f" and {segments[-1]}"
return ", ".join(segments)
async def add_check_reaction(message):
@ -79,18 +46,23 @@ async def reply(message, *args, **kwargs):
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
)
message_responses[message.id] = response
return message_responses[message.id]
async def channel_send(message, *args, **kwargs):
await message.channel.send(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
)
if message.id in message_responses:
await message_responses[message.id].edit(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
)
else:
response = await message.channel.send(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
)
message_responses[message.id] = response
async def invalid_user_handler(interaction):
await interaction.response.send_message(
"you are not the intended receiver of this message!", ephemeral=True
"You are not the intended receiver of this message!", ephemeral=True
)
@ -100,16 +72,3 @@ def filter_secrets(text: str) -> str:
continue
text = text.replace(secret, f"<{secret_name}>")
return text
def load_opus():
print("opus wasn't automatically loaded! trying to load manually...")
for path in ["/usr/lib64/libopus.so.0", "/usr/lib/libopus.so.0"]:
if os.path.exists(path):
try:
disnake.opus.load_opus(path)
print(f"successfully loaded opus from {path}")
return
except Exception as e:
print(f"failed to load opus from {path}: {e}")
raise Exception("could not locate working opus library")

View File

@ -11,33 +11,28 @@ import constants
ytdl = yt_dlp.YoutubeDL(constants.YTDL_OPTIONS)
class CustomAudioSource(disnake.AudioSource):
class TrackedAudioSource(disnake.AudioSource):
def __init__(self, source):
self._source = source
self.read_count = 0
self.count = 0
def read(self) -> bytes:
data = self._source.read()
if data:
self.read_count += 1
self.count += 1
return data
def fast_forward(self, seconds: int):
for _ in range(int(seconds / 0.02)):
self.read()
@property
def progress(self) -> float:
return self.read_count * 0.02
return self.count * 0.02
class YTDLSource(disnake.PCMVolumeTransformer):
def __init__(
self, source: CustomAudioSource, *, data: dict[str, Any], volume: float = 0.5
self, source: TrackedAudioSource, *, data: dict[str, Any], volume: float = 0.5
):
super().__init__(source, volume)
self.description = data.get("description")
self.duration = data.get("duration")
self.original_url = data.get("original_url")
self.thumbnail_url = data.get("thumbnail")
@ -61,7 +56,7 @@ class YTDLSource(disnake.PCMVolumeTransformer):
data = data["entries"][0]
return cls(
CustomAudioSource(
TrackedAudioSource(
disnake.FFmpegPCMAudio(
data["url"] if stream else ytdl.prepare_filename(data),
before_options="-vn -reconnect 1",
@ -87,14 +82,14 @@ class QueuedSong:
return (
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''})\n**duration:** {format_duration(self.player.duration) if self.player.duration else '[live]'}"
+ (
f", **queued by:** <@{self.trigger_message.author.id}>"
f", **queuer:** <@{self.trigger_message.author.id}>"
if show_queuer
else ""
)
)
else:
return (
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''}) [**{format_duration(self.player.duration) if self.player.duration else 'live'}**]"
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''}) **[{format_duration(self.player.duration) if self.player.duration else 'live'}]**"
+ (f" (<@{self.trigger_message.author.id}>)" if show_queuer else "")
)
@ -122,8 +117,8 @@ class QueuedPlayer:
return self.__repr__()
def format_duration(duration: int | float) -> str:
hours, duration = divmod(int(duration), 3600)
def format_duration(duration: int) -> str:
hours, duration = divmod(duration, 3600)
minutes, duration = divmod(duration, 60)
segments = [hours, minutes, duration]
if len(segments) == 3 and segments[0] == 0: