Compare commits

..

1 Commits

Author SHA1 Message Date
b71331a102 refactor: fix casing andn add type hints 2025-02-05 17:24:15 -05:00
33 changed files with 406 additions and 498 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,2 @@
.env
.venv
__pycache__

View File

@@ -7,4 +7,4 @@ COPY . .
RUN pip install -r requirements.txt
CMD ["python", "main.py"]
CMD ["python", "-OO", "main.py"]

View File

@@ -8,9 +8,7 @@ import utils
class ArgumentParser:
def __init__(self, command, description):
self.parser = argparse.ArgumentParser(
command,
description=description,
exit_on_error=False,
command, description=description, exit_on_error=False
)
def print_help(self):
@@ -28,20 +26,21 @@ class ArgumentParser:
async def parse_args(self, message, tokens) -> argparse.Namespace | None:
try:
with contextlib.redirect_stdout(io.StringIO()):
return self.parser.parse_args(tokens[1:])
args = self.parser.parse_args(tokens[1:])
return args
except SystemExit:
await utils.reply(message, f"```\n{self.print_help()}```")
except Exception as e:
await utils.reply(message, f"`{e}`")
def range_type(string: str, lower=0, upper=100) -> int:
def range_type(string: str, min=0, max=100):
try:
value = int(string)
except ValueError as e:
raise argparse.ArgumentTypeError("value is not a valid integer") from e
except ValueError:
raise argparse.ArgumentTypeError("value is not a valid integer")
if lower <= value <= upper:
if min <= value <= max:
return value
raise argparse.ArgumentTypeError(f"value is not in range {lower}-{upper}")
else:
raise argparse.ArgumentTypeError(f"value is not in range {min}-{max}")

View File

@@ -1,8 +0,0 @@
from . import discord, queue, utils, youtubedl
__all__ = [
"discord",
"queue",
"utils",
"youtubedl",
]

View File

@@ -1,39 +0,0 @@
import audioop
import disnake
class TrackedAudioSource(disnake.AudioSource):
def __init__(self, source):
self._source = source
self.read_count = 0
def read(self) -> bytes:
data = self._source.read()
if data:
self.read_count += 1
return data
def fast_forward(self, seconds: int):
for _ in range(int(seconds / 0.02)):
self.read()
@property
def progress(self) -> float:
return self.read_count * 0.02
class PCMVolumeTransformer(disnake.AudioSource):
def __init__(self, original: TrackedAudioSource, volume: float = 1.0) -> None:
if original.is_opus():
raise disnake.ClientException("AudioSource must not be Opus encoded")
self.original = original
self.volume = volume
def cleanup(self) -> None:
self.original.cleanup()
def read(self) -> bytes:
ret = self.original.read()
return audioop.mul(ret, 2, self.volume)

View File

@@ -1,112 +0,0 @@
import collections
from dataclasses import dataclass
from typing import ClassVar, Optional
import disnake
from constants import BAR_LENGTH, EMBED_COLOR
from .utils import format_duration
from .youtubedl import YTDLSource
@dataclass
class Song:
player: YTDLSource
trigger_message: disnake.Message
def format(self, show_queuer=False, hide_preview=False, multiline=False) -> str:
title = f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''})"
duration = (
format_duration(self.player.duration) if self.player.duration else "stream"
)
if multiline:
queue_time = (
self.trigger_message.edited_at or self.trigger_message.created_at
)
return f"{title}\n**duration:** {duration}" + (
f", **queued by:** <@{self.trigger_message.author.id}> <t:{round(queue_time.timestamp())}:R>"
if show_queuer
else ""
)
return f"{title} [**{duration}**]" + (
f" (<@{self.trigger_message.author.id}>)" if show_queuer else ""
)
def embed(self, is_paused=False):
progress = 0
if self.player.duration:
progress = self.player.original.progress / self.player.duration
embed = disnake.Embed(
color=EMBED_COLOR,
title=self.player.title,
url=self.player.original_url,
description=(
f"{'⏸️ ' if is_paused else ''}"
f"`[{'#' * int(progress * BAR_LENGTH)}{'-' * int((1 - progress) * BAR_LENGTH)}]` "
+ (
f"**{format_duration(int(self.player.original.progress))}** / **{format_duration(self.player.duration)}** (**{round(progress * 100)}%**)"
if self.player.duration
else "[**stream**]"
)
),
timestamp=self.trigger_message.edited_at or self.trigger_message.created_at,
)
uploader_value = None
if self.player.uploader_url:
if self.player.uploader:
uploader_value = f"[{self.player.uploader}]({self.player.uploader_url})"
else:
uploader_value = self.player.uploader_url
elif self.player.uploader:
uploader_value = self.player.uploader
if uploader_value:
embed.add_field(name="Uploader", value=uploader_value)
if self.player.like_count:
embed.add_field(name="Likes", value=f"{self.player.like_count:,}")
if self.player.view_count:
embed.add_field(name="Views", value=f"{self.player.view_count:,}")
if self.player.timestamp:
embed.add_field(name="Published", value=f"<t:{int(self.player.timestamp)}>")
if self.player.volume:
embed.add_field(name="Volume", value=f"{int(self.player.volume * 100)}%")
if self.player.thumbnail_url:
embed.set_image(self.player.thumbnail_url)
embed.set_footer(
text=f"Queued by {self.trigger_message.author.name}",
icon_url=(
self.trigger_message.author.avatar.url
if self.trigger_message.author.avatar
else None
),
)
return embed
def __str__(self):
return self.__repr__()
@dataclass
class Player:
queue: ClassVar = collections.deque()
current: Optional[Song] = None
def queue_pop(self):
popped = self.queue.popleft()
self.current = popped
return popped
def queue_push(self, item):
self.queue.append(item)
def queue_push_front(self, item):
self.queue.appendleft(item)
def __str__(self):
return self.__repr__()

View File

@@ -1,7 +0,0 @@
def format_duration(duration: int | float) -> str:
hours, duration = divmod(int(duration), 3600)
minutes, duration = divmod(duration, 60)
segments = [hours, minutes, duration]
if len(segments) == 3 and segments[0] == 0:
del segments[0]
return f"{':'.join(f'{s:0>2}' for s in segments)}"

View File

@@ -1,76 +0,0 @@
import asyncio
from typing import Any, Optional
import disnake
import yt_dlp
from constants import YTDL_OPTIONS
from .discord import PCMVolumeTransformer, TrackedAudioSource
ytdl = yt_dlp.YoutubeDL(YTDL_OPTIONS)
class YTDLSource(PCMVolumeTransformer):
def __init__(
self,
source: TrackedAudioSource,
*,
data: dict[str, Any],
volume: float = 0.5,
):
super().__init__(source, volume)
self.description = data.get("description")
self.duration = data.get("duration")
self.id = data.get("id")
self.like_count = data.get("like_count")
self.original_url = data.get("original_url")
self.thumbnail_url = data.get("thumbnail")
self.timestamp = data.get("timestamp")
self.title = data.get("title")
self.uploader = data.get("uploader")
self.uploader_url = data.get("uploader_url")
self.view_count = data.get("view_count")
@classmethod
async def from_url(
cls,
url,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
stream: bool = False,
):
loop = loop or asyncio.get_event_loop()
data: Any = await loop.run_in_executor(
None,
lambda: ytdl.extract_info(url, download=not stream),
)
if "entries" in data:
if not data["entries"]:
raise Exception("no results found!")
data = data["entries"][0]
if "url" not in data:
raise Exception("no url returned!")
return cls(
TrackedAudioSource(
disnake.FFmpegPCMAudio(
data["url"] if stream else ytdl.prepare_filename(data),
before_options="-vn -reconnect 1",
),
),
data=data,
)
def __repr__(self):
return f"<YTDLSource title={self.title} original_url={self.original_url} duration={self.duration}>"
def __str__(self):
return self.__repr__()
def __reload_module__():
global ytdl
ytdl = yt_dlp.YoutubeDL(YTDL_OPTIONS)

View File

@@ -3,11 +3,11 @@ from .utils import Command, match, match_token, tokenize
__all__ = [
"bot",
"tools",
"utils",
"voice",
"Command",
"match",
"match_token",
"tokenize",
"tools",
"utils",
"voice",
]

View File

@@ -8,9 +8,9 @@ from yt_dlp import version
import arguments
import commands
import utils
from constants import EMBED_COLOR
from state import client, start_time
from utils import format_duration, reply, surround
async def status(message):
@@ -25,41 +25,41 @@ async def status(message):
embed = disnake.Embed(color=EMBED_COLOR)
embed.add_field(
name="Latency",
value=surround(f"{round(client.latency * 1000, 1)} ms"),
value=f"```{round(client.latency * 1000, 1)} ms```",
)
embed.add_field(
name="Memory",
value=surround(f"{round(memory_usage, 1)} MiB"),
value=f"```{round(memory_usage, 1)} MiB```",
)
embed.add_field(
name="Threads",
value=surround(threading.active_count()),
value=f"```{threading.active_count()}```",
)
embed.add_field(
name="Guilds",
value=surround(len(client.guilds)),
value=f"```{len(client.guilds)}```",
)
embed.add_field(
name="Members",
value=surround(member_count),
value=f"```{member_count}```",
)
embed.add_field(
name="Channels",
value=surround(channel_count),
value=f"```{channel_count}```",
)
embed.add_field(
name="Disnake",
value=surround(disnake.__version__),
value=f"```{disnake.__version__}```",
)
embed.add_field(
name="yt-dlp",
value=surround(version.__version__),
value=f"```{version.__version__}```",
)
embed.add_field(
name="Uptime",
value=surround(format_duration(int(time.time() - start_time), short=True)),
value=f"```{utils.format_duration(int(time.time() - start_time), short=True)}```",
)
await reply(message, embed=embed)
await utils.reply(message, embed=embed)
async def uptime(message):
@@ -78,13 +78,15 @@ async def uptime(message):
return
if args.since:
await reply(message, f"{round(start_time)}")
await utils.reply(message, f"{round(start_time)}")
else:
await reply(message, f"up {format_duration(int(time.time() - start_time))}")
await utils.reply(
message, f"up {utils.format_duration(int(time.time() - start_time))}"
)
async def ping(message):
await reply(
await utils.reply(
message,
embed=disnake.Embed(
title="Pong :ping_pong:",
@@ -95,9 +97,9 @@ async def ping(message):
async def help(message):
await reply(
await utils.reply(
message,
", ".join(
[f"`{command.value}`" for command in commands.Command.__members__.values()],
[f"`{command.value}`" for command in commands.Command.__members__.values()]
),
)

View File

@@ -14,13 +14,13 @@ async def lookup(message):
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(
tokens[0],
"look up a discord user or application by ID",
"look up a user or application on discord by their ID",
)
parser.add_argument(
"-a",
"--application",
action="store_true",
help="look up applications instead of users",
help="search for applications instead of users",
)
parser.add_argument(
"id",
@@ -41,7 +41,7 @@ async def lookup(message):
embed = disnake.Embed(description=response["description"], color=EMBED_COLOR)
embed.set_thumbnail(
url=f"https://cdn.discordapp.com/app-icons/{response['id']}/{response['icon']}.webp",
url=f"https://cdn.discordapp.com/app-icons/{response['id']}/{response['icon']}.webp"
)
embed.add_field(name="Application Name", value=response["name"])
embed.add_field(name="Application ID", value="`" + response["id"] + "`")
@@ -102,9 +102,7 @@ async def lookup(message):
for tag in response["tags"]:
bot_tags += tag + ", "
embed.add_field(
name="Tags",
value="None" if bot_tags == "" else bot_tags[:-2],
inline=False,
name="Tags", value="None" if bot_tags == "" else bot_tags[:-2], inline=False
)
else:
try:
@@ -119,10 +117,8 @@ async def lookup(message):
if flag_name != "None":
try:
badges += BADGE_EMOJIS[PUBLIC_FLAGS[flag]]
except Exception as e:
raise Exception(
f"unable to find badge: {PUBLIC_FLAGS[flag]}"
) from e
except Exception:
raise Exception(f"unable to find badge: {PUBLIC_FLAGS[flag]}")
user_object = await client.fetch_user(user.id)
accent_color = 0x000000
@@ -165,11 +161,11 @@ async def clear(message):
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(
tokens[0],
"bulk delete messages in the current channel matching specified criteria",
"bulk delete messages in the current channel matching certain criteria",
)
parser.add_argument(
"count",
type=lambda c: arguments.range_type(c, lower=1, upper=1000),
type=lambda c: arguments.range_type(c, min=1, max=1000),
help="amount of messages to delete",
)
group = parser.add_mutually_exclusive_group()
@@ -263,10 +259,8 @@ async def clear(message):
messages = len(
await message.channel.purge(
limit=args.count,
check=check,
oldest_first=args.oldest_first,
),
limit=args.count, check=check, oldest_first=args.oldest_first
)
)
if not args.delete_command:

View File

@@ -30,19 +30,16 @@ class Command(Enum):
@lru_cache
def match_token(token: str) -> list[Command]:
match token.lower():
case "r":
return [Command.RELOAD]
case "s":
return [Command.SKIP]
case "c":
return [Command.CURRENT]
if token.lower() == "r":
return [Command.RELOAD]
elif token.lower() == "s":
return [Command.SKIP]
if exact_match := list(
filter(
lambda command: command.value == token.lower(),
Command.__members__.values(),
),
)
):
return exact_match
@@ -50,7 +47,7 @@ def match_token(token: str) -> list[Command]:
filter(
lambda command: command.value.startswith(token.lower()),
Command.__members__.values(),
),
)
)

View File

@@ -1,14 +1,17 @@
import disnake
import utils
from .utils import command_allowed
async def join(message):
if message.author.voice:
if message.guild.voice_client:
await message.guild.voice_client.move_to(message.channel)
else:
await message.author.voice.channel.connect()
if message.guild.voice_client:
return await message.guild.voice_client.move_to(message.channel)
elif message.author.voice:
await message.author.voice.channel.connect()
elif isinstance(message.channel, disnake.VoiceChannel):
await message.channel.connect()
else:
await utils.reply(message, "you are not connected to a voice channel!")
return

View File

@@ -1,11 +1,11 @@
import disnake_paginator
import arguments
import disnake_paginator
from constants import EMBED_COLOR
from state import players
import commands
import sponsorblock
import utils
from constants import EMBED_COLOR
from state import players
from .utils import command_allowed
@@ -13,8 +13,7 @@ from .utils import command_allowed
async def playing(message):
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(
tokens[0],
"get information about the currently playing song",
tokens[0], "get information about the currently playing song"
)
parser.add_argument(
"-d",
@@ -50,7 +49,7 @@ async def playing(message):
await utils.reply(
message,
embed=players[message.guild.id].current.embed(
is_paused=message.guild.voice_client.is_paused(),
is_paused=message.guild.voice_client.is_paused()
),
)
else:
@@ -90,13 +89,13 @@ async def pause(message):
async def fast_forward(message):
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(tokens[0], "skip the current sponsorblock segment")
parser = arguments.ArgumentParser(tokens[0], "skip current sponsorblock segment")
parser.add_argument(
"-s",
"--seconds",
nargs="?",
type=lambda v: arguments.range_type(v, lower=0, upper=300),
help="the number of seconds to fast forward instead",
type=lambda v: arguments.range_type(v, min=0, max=300),
help="the amount of seconds to fast forward instead",
)
if not (args := await parser.parse_args(message, tokens)):
return
@@ -111,12 +110,11 @@ async def fast_forward(message):
seconds = args.seconds
if not seconds:
video = await sponsorblock.get_segments(
players[message.guild.id].current.player.id,
players[message.guild.id].current.player.id
)
if not video:
await utils.reply(
message,
"no sponsorblock segments were found for this video!",
message, "no sponsorblock segments were found for this video!"
)
return
@@ -142,7 +140,7 @@ async def volume(message):
parser.add_argument(
"volume",
nargs="?",
type=lambda v: arguments.range_type(v, lower=0, upper=150),
type=lambda v: arguments.range_type(v, min=0, max=150),
help="the volume level (0 - 150)",
)
if not (args := await parser.parse_args(message, tokens)):

View File

@@ -4,28 +4,30 @@ import disnake
import disnake_paginator
import arguments
import audio
import commands
import utils
import youtubedl
from constants import EMBED_COLOR
from state import client, players, trusted_users
from state import client, players
from .playback import resume
from .utils import command_allowed, ensure_joined, play_next
async def queue_or_play(message, edited=False):
if message.guild.id not in players:
players[message.guild.id] = youtubedl.QueuedPlayer()
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(
tokens[0],
"queue a song, list the queue, or resume playback",
tokens[0], "queue a song, list the queue, or resume playback"
)
parser.add_argument("query", nargs="*", help="yt-dlp URL or query to get song")
parser.add_argument(
"-v",
"--volume",
default=50,
type=lambda v: arguments.range_type(v, lower=0, upper=150),
type=lambda v: arguments.range_type(v, min=0, max=150),
help="the volume level (0 - 150) for the specified song",
)
parser.add_argument(
@@ -78,23 +80,19 @@ async def queue_or_play(message, edited=False):
elif not command_allowed(message):
return
if message.guild.id not in players:
players[message.guild.id] = audio.queue.Player()
if edited:
found = next(
filter(
lambda queued: queued.trigger_message.id == message.id,
players[message.guild.id].queue,
),
None,
)
found = None
for queued in players[message.guild.id].queue:
if queued.trigger_message.id == message.id:
found = queued
break
if found:
players[message.guild.id].queue.remove(found)
if args.clear:
players[message.guild.id].queue.clear()
await utils.add_check_reaction(message)
return
elif indices := args.remove_index:
targets = []
for i in indices:
@@ -115,15 +113,15 @@ async def queue_or_play(message, edited=False):
f"removed **{len(targets)}** queued {'song' if len(targets) == 1 else 'songs'}",
)
elif args.remove_title or args.remove_queuer:
targets = set()
targets = []
for queued in players[message.guild.id].queue:
if t := args.remove_title:
if t in queued.player.title:
targets.add(queued)
targets.append(queued)
continue
if q := args.remove_queuer:
if q == queued.trigger_message.author.id:
targets.add(queued)
targets = list(targets)
targets.append(queued)
if not args.match_multiple:
targets = targets[:1]
@@ -142,12 +140,11 @@ async def queue_or_play(message, edited=False):
lambda queued: queued.trigger_message.author.id
== message.author.id,
players[message.guild.id].queue,
),
),
)
)
)
>= 5
and not len(message.guild.voice_client.channel.members) == 2
and message.author.id not in trusted_users
):
await utils.reply(
message,
@@ -157,22 +154,22 @@ async def queue_or_play(message, edited=False):
try:
async with message.channel.typing():
player = await audio.youtubedl.YTDLSource.from_url(
" ".join(query),
loop=client.loop,
stream=True,
player = await youtubedl.YTDLSource.from_url(
" ".join(query), loop=client.loop, stream=True
)
player.volume = float(args.volume) / 100.0
except Exception as e:
await utils.reply(message, f"failed to queue: `{e}`")
await utils.reply(
message, f"**failed to queue:** `{e}`", suppress_embeds=True
)
return
queued = audio.queue.Song(player, message)
queued = youtubedl.QueuedSong(player, message)
if args.now or args.next:
players[message.guild.id].queue_push_front(queued)
players[message.guild.id].queue_add_front(queued)
else:
players[message.guild.id].queue_push(queued)
players[message.guild.id].queue_add(queued)
if not message.guild.voice_client:
await utils.reply(message, "unexpected disconnect from voice channel!")
@@ -197,7 +194,7 @@ async def queue_or_play(message, edited=False):
[
queued.player.duration if queued.player.duration else 0
for queued in players[message.guild.id].queue
],
]
),
natural=True,
)
@@ -222,14 +219,13 @@ async def queue_or_play(message, edited=False):
[
f"**{i + 1}.** {queued.format(show_queuer=True, hide_preview=True, multiline=True)}"
for i, queued in batch
],
]
)
for batch in itertools.batched(
enumerate(players[message.guild.id].queue),
10,
enumerate(players[message.guild.id].queue), 10
)
],
),
)
),
).start(utils.MessageInteractionWrapper(message))
else:
@@ -241,7 +237,7 @@ async def queue_or_play(message, edited=False):
async def skip(message):
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(tokens[0], "skip the song currently playing")
parser = arguments.ArgumentParser(tokens[0], "skip the currently playing song")
parser.add_argument(
"-n",
"--next",

View File

@@ -1,9 +1,9 @@
import disnake
import audio
import sponsorblock
import utils
from constants import EMBED_COLOR, SPONSORBLOCK_CATEGORY_NAMES
import youtubedl
from constants import EMBED_COLOR
from state import players
from .utils import command_allowed
@@ -21,20 +21,18 @@ async def sponsorblock_command(message):
video = await sponsorblock.get_segments(players[message.guild.id].current.player.id)
if not video:
await utils.reply(
message,
"no sponsorblock segments were found for this video!",
message, "no sponsorblock segments were found for this video!"
)
return
text = []
for segment in video["segments"]:
begin, end = map(int, segment["segment"])
if (category := segment["category"]) in SPONSORBLOCK_CATEGORY_NAMES:
category = SPONSORBLOCK_CATEGORY_NAMES[category]
category_name = sponsorblock.CATEGORY_NAMES.get(segment["category"])
current = "**" if progress >= begin and progress < end else ""
text.append(
f"{current}`{audio.utils.format_duration(begin)}` - `{audio.utils.format_duration(end)}`: {category}{current}",
f"{current}`{youtubedl.format_duration(begin)}` - `{youtubedl.format_duration(end)}`: {category_name if category_name else 'Unknown'}{current}"
)
await utils.reply(

View File

@@ -24,8 +24,7 @@ def play_next(message, once=False, first=False):
if message.guild.id in players and players[message.guild.id].queue:
queued = players[message.guild.id].queue_pop()
message.guild.voice_client.play(
queued.player,
after=lambda e: play_after_callback(e, message, once),
queued.player, after=lambda e: play_after_callback(e, message, once)
)
embed = queued.embed()

View File

@@ -19,24 +19,8 @@ BAR_LENGTH = 35
EMBED_COLOR = 0xFF6600
OWNERS = [531392146767347712]
PREFIX = "%"
SPONSORBLOCK_CATEGORY_NAMES = {
"music_offtopic": "non-music",
"selfpromo": "self promotion",
"sponsor": "sponsored",
}
REACTIONS = {
"cat": ["🐈"],
"dog": ["🐕"],
"gn": ["💤", "😪", "😴", "🛌"],
"pizza": ["🍕"],
}
RELOADABLE_MODULES = [
"arguments",
"audio",
"audio.discord",
"audio.queue",
"audio.utils",
"audio.youtubedl",
"commands",
"commands.bot",
"commands.tools",
@@ -56,28 +40,27 @@ RELOADABLE_MODULES = [
"sponsorblock",
"tasks",
"utils",
"utils.common",
"utils.discord",
"voice",
"youtubedl",
"yt_dlp",
"yt_dlp.version",
]
PUBLIC_FLAGS = {
1 << 0: "Discord Employee",
1 << 1: "Discord Partner",
1 << 2: "HypeSquad Events",
1 << 3: "Bug Hunter Level 1",
1 << 6: "HypeSquad Bravery",
1 << 7: "HypeSquad Brilliance",
1 << 8: "HypeSquad Balance",
1 << 9: "Early Supporter",
1 << 10: "Team User",
1 << 14: "Bug Hunter Level 2",
1 << 16: "Verified Bot",
1 << 17: "Verified Bot Developer",
1 << 18: "Discord Certified Moderator",
1 << 19: "HTTP Interactions Only",
1 << 22: "Active Developer",
1: "Discord Employee",
2: "Discord Partner",
4: "HypeSquad Events",
8: "Bug Hunter Level 1",
64: "HypeSquad Bravery",
128: "HypeSquad Brilliance",
256: "HypeSquad Balance",
512: "Early Supporter",
1024: "Team User",
16384: "Bug Hunter Level 2",
65536: "Verified Bot",
131072: "Verified Bot Developer",
262144: "Discord Certified Moderator",
524288: "HTTP Interactions Only",
4194304: "Active Developer",
}
BADGE_EMOJIS = {
"Discord Employee": "<:DiscordStaff:879666899980546068>",

47
core.py
View File

@@ -3,7 +3,6 @@ import contextlib
import importlib
import inspect
import io
import signal
import textwrap
import time
import traceback
@@ -49,23 +48,33 @@ async def on_message(message, edited=False):
try:
if (cooldowns := command_cooldowns.get(message.author.id)) and not edited:
if (end_time := cooldowns.get(matched)) and (
remaining_time := round(end_time - time.time()) > 0
if (end_time := cooldowns.get(matched)) and int(time.time()) < int(
end_time
):
await utils.reply(
message,
f"please wait **{utils.format_duration(remaining_time, natural=True)}** before using this command again!",
f"please wait **{utils.format_duration(int(end_time - time.time()), natural=True)}** before using this command again!",
)
return
match matched:
case C.RELOAD if message.author.id in OWNERS:
reloaded_modules = set()
start = time.time()
reloaded_modules = reload()
rreload(reloaded_modules, __import__("core"))
rreload(reloaded_modules, __import__("extra"))
for module in filter(
lambda v: inspect.ismodule(v) and v.__name__ in RELOADABLE_MODULES,
globals().values(),
):
rreload(reloaded_modules, module)
end = time.time()
debug(
f"reloaded {len(reloaded_modules)} modules in {round(end - start, 2)}s",
)
if __debug__:
debug(
f"reloaded {len(reloaded_modules)} modules in {round(end - start, 2)}s"
)
await utils.add_check_reaction(message)
@@ -158,7 +167,6 @@ async def on_voice_state_update(_, before, after):
channel = before.channel
elif is_empty(after.channel):
channel = after.channel
if channel:
await channel.guild.voice_client.disconnect()
@@ -167,9 +175,9 @@ def rreload(reloaded_modules, module):
reloaded_modules.add(module.__name__)
for submodule in filter(
lambda sm: inspect.ismodule(sm)
and sm.__name__ in RELOADABLE_MODULES
and sm.__name__ not in reloaded_modules,
lambda v: inspect.ismodule(v)
and v.__name__ in RELOADABLE_MODULES
and v.__name__ not in reloaded_modules,
vars(module).values(),
):
rreload(reloaded_modules, submodule)
@@ -178,18 +186,3 @@ def rreload(reloaded_modules, module):
if "__reload_module__" in dir(module):
module.__reload_module__()
def reload(*_):
reloaded_modules = set()
rreload(reloaded_modules, __import__("core"))
rreload(reloaded_modules, __import__("extra"))
for module in filter(
lambda v: inspect.ismodule(v) and v.__name__ in RELOADABLE_MODULES,
globals().values(),
):
rreload(reloaded_modules, module)
return reloaded_modules
signal.signal(signal.SIGUSR1, reload)

View File

@@ -8,15 +8,11 @@ from state import client, kill, players
async def transcript(
message,
languages=["en"],
max_messages=6,
min_messages=3,
upper=True,
message, languages=["en"], max_messages=6, min_messages=3, upper=True
):
initial_id = message.guild.voice_client.source.id
transcript_list = youtube_transcript_api.YouTubeTranscriptApi.list_transcripts(
initial_id,
initial_id
)
try:
transcript = transcript_list.find_manually_created_transcript(languages).fetch()
@@ -48,7 +44,7 @@ async def transcript(
await messages.pop().delete()
else:
await message.channel.delete_messages(
[messages.pop() for _ in range(count)],
[messages.pop() for _ in range(count)]
)
except Exception:
pass
@@ -81,20 +77,19 @@ def messages_per_second(limit=500):
average = 1
print(
f"I am receiving **{average} {'message' if average == 1 else 'messages'} per second** "
f"from **{len(members)} {'member' if len(members) == 1 else 'members'}** across **{len(guilds)} {'guild' if len(guilds) == 1 else 'guilds'}**",
f"from **{len(members)} {'member' if len(members) == 1 else 'members'}** across **{len(guilds)} {'guild' if len(guilds) == 1 else 'guilds'}**"
)
async def auto_count(channel_id: int):
if (channel := await client.fetch_channel(channel_id)) and isinstance(
channel,
disnake.TextChannel,
channel, disnake.TextChannel
):
last_message = (await channel.history(limit=1).flatten())[0]
try:
result = str(
int("".join(filter(lambda d: d in string.digits, last_message.content)))
+ 1,
+ 1
)
except Exception:
result = "where number"

11
fun.py
View File

@@ -1,13 +1,10 @@
import random
import commands
from constants import REACTIONS
async def on_message(message):
if random.random() < 0.01:
tokens = commands.tokenize(message.content, remove_prefix=False)
for keyword, options in REACTIONS.items():
if keyword in tokens:
await message.add_reaction(random.choice(options))
break
if random.random() < 0.01 and "gn" in commands.tokenize(
message.content, remove_prefix=False
):
await message.add_reaction(random.choice(["💤", "😪", "😴", "🛌"]))

View File

@@ -7,7 +7,7 @@ from state import client
if __name__ == "__main__":
logging.basicConfig(
format=(
"%(asctime)s %(levelname)s %(name)s:%(module)s %(message)s"
"%(asctime)s %(levelname)s %(name):%(module)s %(message)s"
if __debug__
else "%(asctime)s %(levelname)s %(message)s"
),

View File

@@ -5,4 +5,4 @@ disnake_paginator
psutil
PyNaCl
youtube_transcript_api
yt-dlp[default] @ https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz
yt-dlp

View File

@@ -1,36 +1,28 @@
import hashlib
import json
import aiohttp
from state import sponsorblock_cache
categories = json.dumps(
[
"interaction",
"intro",
"music_offtopic",
"outro",
"preview",
"selfpromo",
"sponsor",
],
)
CATEGORY_NAMES = {
"music_offtopic": "non-music",
"sponsor": "sponsored",
}
async def get_segments(video_id: str):
if video_id in sponsorblock_cache:
return sponsorblock_cache[video_id]
hash_prefix = hashlib.sha256(video_id.encode()).hexdigest()[:4]
hashPrefix = hashlib.sha256(video_id.encode()).hexdigest()[:4]
session = aiohttp.ClientSession()
response = await session.get(
f"https://sponsor.ajay.app/api/skipSegments/{hash_prefix}",
params={"categories": categories},
f"https://sponsor.ajay.app/api/skipSegments/{hashPrefix}",
params={"categories": '["sponsor", "music_offtopic"]'},
)
if response.status == 200 and (
results := list(
filter(lambda v: video_id == v["videoID"], await response.json()),
filter(lambda v: video_id == v["videoID"], await response.json())
)
):
sponsorblock_cache[video_id] = results[0]

View File

@@ -15,6 +15,5 @@ idle_tracker = {"is_idle": False, "last_used": time.time()}
kill = {"transcript": False}
message_responses = LimitedSizeDict()
players = {}
sponsorblock_cache = LimitedSizeDict()
sponsorblock_cache = LimitedSizeDict(size_limit=100)
start_time = time.time()
trusted_users = []

View File

@@ -11,7 +11,7 @@ async def cleanup():
debug("spawned cleanup thread")
while True:
await asyncio.sleep(3600)
await asyncio.sleep(3600 * 12)
targets = []
for guild_id, player in players.items():
@@ -19,7 +19,8 @@ async def cleanup():
targets.append(guild_id)
for target in targets:
del players[target]
debug(f"cleanup thread removed {len(targets)} empty players")
if __debug__:
debug(f"cleanup removed {len(targets)} empty players")
if (
not idle_tracker["is_idle"]

View File

@@ -1,3 +1,3 @@
from . import test_filter_secrets, test_format_duration
__all__ = ["test_filter_secrets", "test_format_duration"]
__all__ = ["test_format_duration", "test_filter_secrets"]

View File

@@ -7,15 +7,15 @@ class TestFilterSecrets(unittest.TestCase):
def test_filter_secrets(self):
secret = "PLACEHOLDER_TOKEN"
self.assertFalse(
secret in utils.filter_secrets(f"HELLO{secret}WORLD", {"TOKEN": secret}),
secret in utils.filter_secrets(f"HELLO{secret}WORLD", {"TOKEN": secret})
)
self.assertFalse(secret in utils.filter_secrets(secret, {"TOKEN": secret}))
self.assertFalse(
secret in utils.filter_secrets(f"123{secret}", {"TOKEN": secret}),
secret in utils.filter_secrets(f"123{secret}", {"TOKEN": secret})
)
self.assertFalse(
secret in utils.filter_secrets(f"{secret}{secret}", {"TOKEN": secret}),
secret in utils.filter_secrets(f"{secret}{secret}", {"TOKEN": secret})
)
self.assertFalse(
secret in utils.filter_secrets(f"{secret}@#(*&*$)", {"TOKEN": secret}),
secret in utils.filter_secrets(f"{secret}@#(*&*$)", {"TOKEN": secret})
)

View File

@@ -1,13 +1,13 @@
import unittest
import audio
import utils
import youtubedl
class TestFormatDuration(unittest.TestCase):
def test_audio(self):
def test_youtubedl(self):
def f(s):
return audio.utils.format_duration(s)
return youtubedl.format_duration(s)
self.assertEqual(f(0), "00:00")
self.assertEqual(f(0.5), "00:00")

View File

@@ -1,4 +1,4 @@
from .common import LimitedSizeDict, filter_secrets, format_duration, surround
from .common import LimitedSizeDict, filter_secrets, format_duration
from .discord import (
ChannelResponseWrapper,
MessageInteractionWrapper,
@@ -24,5 +24,4 @@ __all__ = [
"MessageInteractionWrapper",
"reply",
"snowflake_timestamp",
"surround",
]

View File

@@ -3,11 +3,7 @@ from collections import OrderedDict
from constants import SECRETS
def surround(inner: str, outer="```") -> str:
return outer + str(inner) + outer
def format_duration(duration: int, natural: bool = False, short: bool = False) -> str:
def format_duration(duration: int, natural: bool = False, short: bool = False):
def format_plural(noun, count):
if short:
return noun[0]
@@ -50,7 +46,7 @@ def filter_secrets(text: str, secrets=SECRETS) -> str:
class LimitedSizeDict(OrderedDict):
def __init__(self, *args, **kwargs):
self.size_limit = kwargs.pop("size_limit", 100)
self.size_limit = kwargs.pop("size_limit", 1000)
super().__init__(*args, **kwargs)
self._check_size_limit()

View File

@@ -1,18 +1,14 @@
import os
import time
from logging import error, info
from pathlib import Path
import disnake
import commands
from constants import OWNERS
from state import command_cooldowns, message_responses
def cooldown(message, cooldown_time: int):
if message.author.id in OWNERS:
return
possible_commands = commands.match(message.content)
if not possible_commands or len(possible_commands) > 1:
return
@@ -34,9 +30,7 @@ async def reply(message, *args, **kwargs):
try:
await message_responses[message.id].edit(
*args,
**kwargs,
allowed_mentions=disnake.AllowedMentions.none(),
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
)
return
except Exception:
@@ -44,9 +38,7 @@ async def reply(message, *args, **kwargs):
try:
response = await message.reply(
*args,
**kwargs,
allowed_mentions=disnake.AllowedMentions.none(),
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
)
except Exception:
response = await channel_send(message, *args, **kwargs)
@@ -56,15 +48,13 @@ async def reply(message, *args, **kwargs):
async def channel_send(message, *args, **kwargs):
await message.channel.send(
*args,
**kwargs,
allowed_mentions=disnake.AllowedMentions.none(),
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
)
def load_opus():
for path in filter(
lambda p: Path(p).exists(),
lambda p: os.path.exists(p),
["/usr/lib64/libopus.so.0", "/usr/lib/libopus.so.0"],
):
try:
@@ -76,8 +66,8 @@ def load_opus():
raise Exception("could not locate working opus library")
def snowflake_timestamp(snowflake) -> int:
return round(((snowflake >> 22) + 1420070400000) / 1000)
def snowflake_timestamp(id):
return round(((id >> 22) + 1420070400000) / 1000)
async def add_check_reaction(message):
@@ -86,8 +76,7 @@ async def add_check_reaction(message):
async def invalid_user_handler(interaction):
await interaction.response.send_message(
"you are not the intended receiver of this message!",
ephemeral=True,
"you are not the intended receiver of this message!", ephemeral=True
)
@@ -97,7 +86,8 @@ class ChannelResponseWrapper:
self.sent_message = None
async def send_message(self, **kwargs):
kwargs.pop("ephemeral", None)
if "ephemeral" in kwargs:
del kwargs["ephemeral"]
self.sent_message = await reply(self.message, **kwargs)
async def edit_message(self, content=None, embed=None, view=None):

220
youtubedl.py Normal file
View File

@@ -0,0 +1,220 @@
import asyncio
import audioop
import collections
from dataclasses import dataclass
from typing import Any, Optional
import disnake
import yt_dlp
from constants import BAR_LENGTH, EMBED_COLOR, YTDL_OPTIONS
ytdl = yt_dlp.YoutubeDL(YTDL_OPTIONS)
class PCMVolumeTransformer(disnake.AudioSource):
def __init__(self, original: disnake.AudioSource, volume: float = 1.0) -> None:
if original.is_opus():
raise disnake.ClientException("AudioSource must not be Opus encoded.")
self.original = original
self.volume = volume
@property
def volume(self) -> float:
return self._volume
@volume.setter
def volume(self, value: float) -> None:
self._volume = max(value, 0.0)
def cleanup(self) -> None:
self.original.cleanup()
def read(self) -> bytes:
ret = self.original.read()
return audioop.mul(ret, 2, self._volume)
class CustomAudioSource(disnake.AudioSource):
def __init__(self, source):
self._source = source
self.read_count = 0
def read(self) -> bytes:
data = self._source.read()
if data:
self.read_count += 1
return data
def fast_forward(self, seconds: int):
for _ in range(int(seconds / 0.02)):
self.read()
@property
def progress(self) -> float:
return self.read_count * 0.02
class YTDLSource(PCMVolumeTransformer):
def __init__(
self, source: CustomAudioSource, *, data: dict[str, Any], volume: float = 0.5
):
super().__init__(source, volume)
self.description = data.get("description")
self.duration = data.get("duration")
self.id = data.get("id")
self.like_count = data.get("like_count")
self.original_url = data.get("original_url")
self.thumbnail_url = data.get("thumbnail")
self.timestamp = data.get("timestamp")
self.title = data.get("title")
self.uploader = data.get("uploader")
self.uploader_url = data.get("uploader_url")
self.view_count = data.get("view_count")
@classmethod
async def from_url(
cls,
url,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
stream: bool = False,
):
loop = loop or asyncio.get_event_loop()
data: Any = await loop.run_in_executor(
None, lambda: ytdl.extract_info(url, download=not stream)
)
if "entries" in data:
if not data["entries"]:
raise Exception("no entries provided by yt-dlp!")
data = data["entries"][0]
return cls(
CustomAudioSource(
disnake.FFmpegPCMAudio(
data["url"] if stream else ytdl.prepare_filename(data),
before_options="-vn -reconnect 1",
)
),
data=data,
)
def __repr__(self):
return f"<YTDLSource title={self.title} original_url=<{self.original_url}> duration={self.duration}>"
def __str__(self):
return self.__repr__()
@dataclass
class QueuedSong:
player: YTDLSource
trigger_message: disnake.Message
def format(self, show_queuer=False, hide_preview=False, multiline=False) -> str:
if multiline:
return (
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''})\n**duration:** {format_duration(self.player.duration) if self.player.duration else '[live]'}"
+ (
f", **queued by:** <@{self.trigger_message.author.id}>"
if show_queuer
else ""
)
)
else:
return (
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''}) [**{format_duration(self.player.duration) if self.player.duration else 'live'}**]"
+ (f" (<@{self.trigger_message.author.id}>)" if show_queuer else "")
)
def embed(self, is_paused=False):
progress = 0
if self.player.duration:
progress = self.player.original.progress / self.player.duration
embed = disnake.Embed(
color=EMBED_COLOR,
title=self.player.title,
url=self.player.original_url,
description=(
f"{'⏸️ ' if is_paused else ''}"
f"`[{'#' * int(progress * BAR_LENGTH)}{'-' * int((1 - progress) * BAR_LENGTH)}]` "
+ (
f"**{format_duration(int(self.player.original.progress))}** / **{format_duration(self.player.duration)}** (**{round(progress * 100)}%**)"
if self.player.duration
else "[**live**]"
)
),
)
if self.player.uploader_url:
embed.add_field(
name="Uploader",
value=f"[{self.player.uploader}]({self.player.uploader_url})",
)
else:
embed.add_field(
name="Uploader",
value=self.player.uploader,
)
embed.add_field(
name="Likes",
value=f"{self.player.like_count:,}"
if self.player.like_count
else "Unknown",
)
embed.add_field(name="Views", value=f"{self.player.view_count:,}")
embed.add_field(name="Published", value=f"<t:{self.player.timestamp}>")
embed.add_field(name="Volume", value=f"{int(self.player.volume * 100)}%")
embed.set_image(self.player.thumbnail_url)
embed.set_footer(
text=f"queued by {self.trigger_message.author.name}",
icon_url=(
self.trigger_message.author.avatar.url
if self.trigger_message.author.avatar
else None
),
)
return embed
def __str__(self):
return self.__repr__()
@dataclass
class QueuedPlayer:
queue = collections.deque()
current: Optional[QueuedSong] = None
def queue_pop(self):
popped = self.queue.popleft()
self.current = popped
return popped
def queue_add(self, item):
self.queue.append(item)
def queue_add_front(self, item):
self.queue.appendleft(item)
def __str__(self):
return self.__repr__()
def format_duration(duration: int | float) -> str:
hours, duration = divmod(int(duration), 3600)
minutes, duration = divmod(duration, 60)
segments = [hours, minutes, duration]
if len(segments) == 3 and segments[0] == 0:
del segments[0]
return f"{':'.join(f'{s:0>2}' for s in segments)}"
def __reload_module__():
global ytdl
ytdl = yt_dlp.YoutubeDL(YTDL_OPTIONS)