refactor: follow more guidelines

This commit is contained in:
Ryan 2025-04-03 17:32:44 -04:00
parent f360566824
commit 7f62b0f273
Signed by: ErrorNoInternet
GPG Key ID: 2486BFB7B1E6A4A3
15 changed files with 100 additions and 71 deletions

View File

@ -8,7 +8,9 @@ import utils
class ArgumentParser: class ArgumentParser:
def __init__(self, command, description): def __init__(self, command, description):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
command, description=description, exit_on_error=False command,
description=description,
exit_on_error=False,
) )
def print_help(self): def print_help(self):
@ -26,21 +28,20 @@ class ArgumentParser:
async def parse_args(self, message, tokens) -> argparse.Namespace | None: async def parse_args(self, message, tokens) -> argparse.Namespace | None:
try: try:
with contextlib.redirect_stdout(io.StringIO()): with contextlib.redirect_stdout(io.StringIO()):
args = self.parser.parse_args(tokens[1:]) return self.parser.parse_args(tokens[1:])
return args
except SystemExit: except SystemExit:
await utils.reply(message, f"```\n{self.print_help()}```") await utils.reply(message, f"```\n{self.print_help()}```")
except Exception as e: except Exception as e:
await utils.reply(message, f"`{e}`") await utils.reply(message, f"`{e}`")
def range_type(string: str, min=0, max=100): def range_type(string: str, lower=0, upper=100):
try: try:
value = int(string) value = int(string)
except ValueError: except ValueError:
raise argparse.ArgumentTypeError("value is not a valid integer") raise argparse.ArgumentTypeError("value is not a valid integer")
if min <= value <= max: if lower <= value <= upper:
return value return value
else:
raise argparse.ArgumentTypeError(f"value is not in range {min}-{max}") raise argparse.ArgumentTypeError(f"value is not in range {lower}-{upper}")

View File

@ -1,6 +1,6 @@
import collections import collections
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import ClassVar, Optional
import disnake import disnake
@ -16,19 +16,18 @@ class Song:
trigger_message: disnake.Message trigger_message: disnake.Message
def format(self, show_queuer=False, hide_preview=False, multiline=False) -> str: def format(self, show_queuer=False, hide_preview=False, multiline=False) -> str:
title = f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''})"
duration = (
format_duration(self.player.duration) if self.player.duration else "stream"
)
if multiline: if multiline:
return ( return f"{title}\n**duration:** {duration}" + (
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''})\n**duration:** {format_duration(self.player.duration) if self.player.duration else '[stream]'}"
+ (
f", **queued by:** <@{self.trigger_message.author.id}>" f", **queued by:** <@{self.trigger_message.author.id}>"
if show_queuer if show_queuer
else "" else ""
) )
) return f"{title} [**{duration}**]" + (
else: f" (<@{self.trigger_message.author.id}>)" if show_queuer else ""
return (
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''}) [**{format_duration(self.player.duration) if self.player.duration else 'stream'}**]"
+ (f" (<@{self.trigger_message.author.id}>)" if show_queuer else "")
) )
def embed(self, is_paused=False): def embed(self, is_paused=False):
@ -90,7 +89,7 @@ class Song:
@dataclass @dataclass
class Player: class Player:
queue = collections.deque() queue: ClassVar = collections.deque()
current: Optional[Song] = None current: Optional[Song] = None
def queue_pop(self): def queue_pop(self):

View File

@ -13,7 +13,11 @@ ytdl = yt_dlp.YoutubeDL(YTDL_OPTIONS)
class YTDLSource(PCMVolumeTransformer): class YTDLSource(PCMVolumeTransformer):
def __init__( def __init__(
self, source: TrackedAudioSource, *, data: dict[str, Any], volume: float = 0.5 self,
source: TrackedAudioSource,
*,
data: dict[str, Any],
volume: float = 0.5,
): ):
super().__init__(source, volume) super().__init__(source, volume)
@ -39,7 +43,8 @@ class YTDLSource(PCMVolumeTransformer):
): ):
loop = loop or asyncio.get_event_loop() loop = loop or asyncio.get_event_loop()
data: Any = await loop.run_in_executor( data: Any = await loop.run_in_executor(
None, lambda: ytdl.extract_info(url, download=not stream) None,
lambda: ytdl.extract_info(url, download=not stream),
) )
if "entries" in data: if "entries" in data:
@ -54,7 +59,7 @@ class YTDLSource(PCMVolumeTransformer):
disnake.FFmpegPCMAudio( disnake.FFmpegPCMAudio(
data["url"] if stream else ytdl.prepare_filename(data), data["url"] if stream else ytdl.prepare_filename(data),
before_options="-vn -reconnect 1", before_options="-vn -reconnect 1",
) ),
), ),
data=data, data=data,
) )

View File

@ -98,6 +98,6 @@ async def help(message):
await reply( await reply(
message, message,
", ".join( ", ".join(
[f"`{command.value}`" for command in commands.Command.__members__.values()] [f"`{command.value}`" for command in commands.Command.__members__.values()],
), ),
) )

View File

@ -41,7 +41,7 @@ async def lookup(message):
embed = disnake.Embed(description=response["description"], color=EMBED_COLOR) embed = disnake.Embed(description=response["description"], color=EMBED_COLOR)
embed.set_thumbnail( embed.set_thumbnail(
url=f"https://cdn.discordapp.com/app-icons/{response['id']}/{response['icon']}.webp" url=f"https://cdn.discordapp.com/app-icons/{response['id']}/{response['icon']}.webp",
) )
embed.add_field(name="Application Name", value=response["name"]) embed.add_field(name="Application Name", value=response["name"])
embed.add_field(name="Application ID", value="`" + response["id"] + "`") embed.add_field(name="Application ID", value="`" + response["id"] + "`")
@ -102,7 +102,9 @@ async def lookup(message):
for tag in response["tags"]: for tag in response["tags"]:
bot_tags += tag + ", " bot_tags += tag + ", "
embed.add_field( embed.add_field(
name="Tags", value="None" if bot_tags == "" else bot_tags[:-2], inline=False name="Tags",
value="None" if bot_tags == "" else bot_tags[:-2],
inline=False,
) )
else: else:
try: try:
@ -165,7 +167,7 @@ async def clear(message):
) )
parser.add_argument( parser.add_argument(
"count", "count",
type=lambda c: arguments.range_type(c, min=1, max=1000), type=lambda c: arguments.range_type(c, lower=1, upper=1000),
help="amount of messages to delete", help="amount of messages to delete",
) )
group = parser.add_mutually_exclusive_group() group = parser.add_mutually_exclusive_group()
@ -259,8 +261,10 @@ async def clear(message):
messages = len( messages = len(
await message.channel.purge( await message.channel.purge(
limit=args.count, check=check, oldest_first=args.oldest_first limit=args.count,
) check=check,
oldest_first=args.oldest_first,
),
) )
if not args.delete_command: if not args.delete_command:

View File

@ -42,7 +42,7 @@ def match_token(token: str) -> list[Command]:
filter( filter(
lambda command: command.value == token.lower(), lambda command: command.value == token.lower(),
Command.__members__.values(), Command.__members__.values(),
) ),
): ):
return exact_match return exact_match
@ -50,7 +50,7 @@ def match_token(token: str) -> list[Command]:
filter( filter(
lambda command: command.value.startswith(token.lower()), lambda command: command.value.startswith(token.lower()),
Command.__members__.values(), Command.__members__.values(),
) ),
) )

View File

@ -13,7 +13,8 @@ from .utils import command_allowed
async def playing(message): async def playing(message):
tokens = commands.tokenize(message.content) tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser( parser = arguments.ArgumentParser(
tokens[0], "get information about the currently playing song" tokens[0],
"get information about the currently playing song",
) )
parser.add_argument( parser.add_argument(
"-d", "-d",
@ -49,7 +50,7 @@ async def playing(message):
await utils.reply( await utils.reply(
message, message,
embed=players[message.guild.id].current.embed( embed=players[message.guild.id].current.embed(
is_paused=message.guild.voice_client.is_paused() is_paused=message.guild.voice_client.is_paused(),
), ),
) )
else: else:
@ -94,7 +95,7 @@ async def fast_forward(message):
"-s", "-s",
"--seconds", "--seconds",
nargs="?", nargs="?",
type=lambda v: arguments.range_type(v, min=0, max=300), type=lambda v: arguments.range_type(v, lower=0, upper=300),
help="the amount of seconds to fast forward instead", help="the amount of seconds to fast forward instead",
) )
if not (args := await parser.parse_args(message, tokens)): if not (args := await parser.parse_args(message, tokens)):
@ -110,11 +111,12 @@ async def fast_forward(message):
seconds = args.seconds seconds = args.seconds
if not seconds: if not seconds:
video = await sponsorblock.get_segments( video = await sponsorblock.get_segments(
players[message.guild.id].current.player.id players[message.guild.id].current.player.id,
) )
if not video: if not video:
await utils.reply( await utils.reply(
message, "no sponsorblock segments were found for this video!" message,
"no sponsorblock segments were found for this video!",
) )
return return
@ -140,7 +142,7 @@ async def volume(message):
parser.add_argument( parser.add_argument(
"volume", "volume",
nargs="?", nargs="?",
type=lambda v: arguments.range_type(v, min=0, max=150), type=lambda v: arguments.range_type(v, lower=0, upper=150),
help="the volume level (0 - 150)", help="the volume level (0 - 150)",
) )
if not (args := await parser.parse_args(message, tokens)): if not (args := await parser.parse_args(message, tokens)):

View File

@ -20,14 +20,15 @@ async def queue_or_play(message, edited=False):
tokens = commands.tokenize(message.content) tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser( parser = arguments.ArgumentParser(
tokens[0], "queue a song, list the queue, or resume playback" tokens[0],
"queue a song, list the queue, or resume playback",
) )
parser.add_argument("query", nargs="*", help="yt-dlp URL or query to get song") parser.add_argument("query", nargs="*", help="yt-dlp URL or query to get song")
parser.add_argument( parser.add_argument(
"-v", "-v",
"--volume", "--volume",
default=50, default=50,
type=lambda v: arguments.range_type(v, min=0, max=150), type=lambda v: arguments.range_type(v, lower=0, upper=150),
help="the volume level (0 - 150) for the specified song", help="the volume level (0 - 150) for the specified song",
) )
parser.add_argument( parser.add_argument(
@ -140,8 +141,8 @@ async def queue_or_play(message, edited=False):
lambda queued: queued.trigger_message.author.id lambda queued: queued.trigger_message.author.id
== message.author.id, == message.author.id,
players[message.guild.id].queue, players[message.guild.id].queue,
) ),
) ),
) )
>= 5 >= 5
and not len(message.guild.voice_client.channel.members) == 2 and not len(message.guild.voice_client.channel.members) == 2
@ -155,7 +156,9 @@ async def queue_or_play(message, edited=False):
try: try:
async with message.channel.typing(): async with message.channel.typing():
player = await audio.youtubedl.YTDLSource.from_url( player = await audio.youtubedl.YTDLSource.from_url(
" ".join(query), loop=client.loop, stream=True " ".join(query),
loop=client.loop,
stream=True,
) )
player.volume = float(args.volume) / 100.0 player.volume = float(args.volume) / 100.0
except Exception as e: except Exception as e:
@ -192,7 +195,7 @@ async def queue_or_play(message, edited=False):
[ [
queued.player.duration if queued.player.duration else 0 queued.player.duration if queued.player.duration else 0
for queued in players[message.guild.id].queue for queued in players[message.guild.id].queue
] ],
), ),
natural=True, natural=True,
) )
@ -217,13 +220,14 @@ async def queue_or_play(message, edited=False):
[ [
f"**{i + 1}.** {queued.format(show_queuer=True, hide_preview=True, multiline=True)}" f"**{i + 1}.** {queued.format(show_queuer=True, hide_preview=True, multiline=True)}"
for i, queued in batch for i, queued in batch
]
)
for batch in itertools.batched(
enumerate(players[message.guild.id].queue), 10
)
], ],
) )
for batch in itertools.batched(
enumerate(players[message.guild.id].queue),
10,
)
],
),
), ),
).start(utils.MessageInteractionWrapper(message)) ).start(utils.MessageInteractionWrapper(message))
else: else:

View File

@ -21,7 +21,8 @@ async def sponsorblock_command(message):
video = await sponsorblock.get_segments(players[message.guild.id].current.player.id) video = await sponsorblock.get_segments(players[message.guild.id].current.player.id)
if not video: if not video:
await utils.reply( await utils.reply(
message, "no sponsorblock segments were found for this video!" message,
"no sponsorblock segments were found for this video!",
) )
return return
@ -33,7 +34,7 @@ async def sponsorblock_command(message):
current = "**" if progress >= begin and progress < end else "" current = "**" if progress >= begin and progress < end else ""
text.append( text.append(
f"{current}`{audio.utils.format_duration(begin)}` - `{audio.utils.format_duration(end)}`: {category}{current}" f"{current}`{audio.utils.format_duration(begin)}` - `{audio.utils.format_duration(end)}`: {category}{current}",
) )
await utils.reply( await utils.reply(

View File

@ -24,7 +24,8 @@ def play_next(message, once=False, first=False):
if message.guild.id in players and players[message.guild.id].queue: if message.guild.id in players and players[message.guild.id].queue:
queued = players[message.guild.id].queue_pop() queued = players[message.guild.id].queue_pop()
message.guild.voice_client.play( message.guild.voice_client.play(
queued.player, after=lambda e: play_after_callback(e, message, once) queued.player,
after=lambda e: play_after_callback(e, message, once),
) )
embed = queued.embed() embed = queued.embed()

View File

@ -72,7 +72,7 @@ async def on_message(message, edited=False):
end = time.time() end = time.time()
debug( debug(
f"reloaded {len(reloaded_modules)} modules in {round(end - start, 2)}s" f"reloaded {len(reloaded_modules)} modules in {round(end - start, 2)}s",
) )
await utils.add_check_reaction(message) await utils.add_check_reaction(message)

View File

@ -8,11 +8,15 @@ from state import client, kill, players
async def transcript( async def transcript(
message, languages=["en"], max_messages=6, min_messages=3, upper=True message,
languages=["en"],
max_messages=6,
min_messages=3,
upper=True,
): ):
initial_id = message.guild.voice_client.source.id initial_id = message.guild.voice_client.source.id
transcript_list = youtube_transcript_api.YouTubeTranscriptApi.list_transcripts( transcript_list = youtube_transcript_api.YouTubeTranscriptApi.list_transcripts(
initial_id initial_id,
) )
try: try:
transcript = transcript_list.find_manually_created_transcript(languages).fetch() transcript = transcript_list.find_manually_created_transcript(languages).fetch()
@ -44,7 +48,7 @@ async def transcript(
await messages.pop().delete() await messages.pop().delete()
else: else:
await message.channel.delete_messages( await message.channel.delete_messages(
[messages.pop() for _ in range(count)] [messages.pop() for _ in range(count)],
) )
except Exception: except Exception:
pass pass
@ -77,19 +81,20 @@ def messages_per_second(limit=500):
average = 1 average = 1
print( print(
f"I am receiving **{average} {'message' if average == 1 else 'messages'} per second** " f"I am receiving **{average} {'message' if average == 1 else 'messages'} per second** "
f"from **{len(members)} {'member' if len(members) == 1 else 'members'}** across **{len(guilds)} {'guild' if len(guilds) == 1 else 'guilds'}**" f"from **{len(members)} {'member' if len(members) == 1 else 'members'}** across **{len(guilds)} {'guild' if len(guilds) == 1 else 'guilds'}**",
) )
async def auto_count(channel_id: int): async def auto_count(channel_id: int):
if (channel := await client.fetch_channel(channel_id)) and isinstance( if (channel := await client.fetch_channel(channel_id)) and isinstance(
channel, disnake.TextChannel channel,
disnake.TextChannel,
): ):
last_message = (await channel.history(limit=1).flatten())[0] last_message = (await channel.history(limit=1).flatten())[0]
try: try:
result = str( result = str(
int("".join(filter(lambda d: d in string.digits, last_message.content))) int("".join(filter(lambda d: d in string.digits, last_message.content)))
+ 1 + 1,
) )
except Exception: except Exception:
result = "where number" result = "where number"

View File

@ -14,7 +14,7 @@ categories = json.dumps(
"preview", "preview",
"selfpromo", "selfpromo",
"sponsor", "sponsor",
] ],
) )
@ -30,7 +30,7 @@ async def get_segments(video_id: str):
) )
if response.status == 200 and ( if response.status == 200 and (
results := list( results := list(
filter(lambda v: video_id == v["videoID"], await response.json()) filter(lambda v: video_id == v["videoID"], await response.json()),
) )
): ):
sponsorblock_cache[video_id] = results[0] sponsorblock_cache[video_id] = results[0]

View File

@ -7,15 +7,15 @@ class TestFilterSecrets(unittest.TestCase):
def test_filter_secrets(self): def test_filter_secrets(self):
secret = "PLACEHOLDER_TOKEN" secret = "PLACEHOLDER_TOKEN"
self.assertFalse( self.assertFalse(
secret in utils.filter_secrets(f"HELLO{secret}WORLD", {"TOKEN": secret}) secret in utils.filter_secrets(f"HELLO{secret}WORLD", {"TOKEN": secret}),
) )
self.assertFalse(secret in utils.filter_secrets(secret, {"TOKEN": secret})) self.assertFalse(secret in utils.filter_secrets(secret, {"TOKEN": secret}))
self.assertFalse( self.assertFalse(
secret in utils.filter_secrets(f"123{secret}", {"TOKEN": secret}) secret in utils.filter_secrets(f"123{secret}", {"TOKEN": secret}),
) )
self.assertFalse( self.assertFalse(
secret in utils.filter_secrets(f"{secret}{secret}", {"TOKEN": secret}) secret in utils.filter_secrets(f"{secret}{secret}", {"TOKEN": secret}),
) )
self.assertFalse( self.assertFalse(
secret in utils.filter_secrets(f"{secret}@#(*&*$)", {"TOKEN": secret}) secret in utils.filter_secrets(f"{secret}@#(*&*$)", {"TOKEN": secret}),
) )

View File

@ -1,6 +1,6 @@
import os
import time import time
from logging import error, info from logging import error, info
from pathlib import Path
import disnake import disnake
@ -34,7 +34,9 @@ async def reply(message, *args, **kwargs):
try: try:
await message_responses[message.id].edit( await message_responses[message.id].edit(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none() *args,
**kwargs,
allowed_mentions=disnake.AllowedMentions.none(),
) )
return return
except Exception: except Exception:
@ -42,7 +44,9 @@ async def reply(message, *args, **kwargs):
try: try:
response = await message.reply( response = await message.reply(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none() *args,
**kwargs,
allowed_mentions=disnake.AllowedMentions.none(),
) )
except Exception: except Exception:
response = await channel_send(message, *args, **kwargs) response = await channel_send(message, *args, **kwargs)
@ -52,13 +56,15 @@ async def reply(message, *args, **kwargs):
async def channel_send(message, *args, **kwargs): async def channel_send(message, *args, **kwargs):
await message.channel.send( await message.channel.send(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none() *args,
**kwargs,
allowed_mentions=disnake.AllowedMentions.none(),
) )
def load_opus(): def load_opus():
for path in filter( for path in filter(
lambda p: os.path.exists(p), lambda p: Path(p).exists(),
["/usr/lib64/libopus.so.0", "/usr/lib/libopus.so.0"], ["/usr/lib64/libopus.so.0", "/usr/lib/libopus.so.0"],
): ):
try: try:
@ -70,8 +76,8 @@ def load_opus():
raise Exception("could not locate working opus library") raise Exception("could not locate working opus library")
def snowflake_timestamp(id): def snowflake_timestamp(snowflake):
return round(((id >> 22) + 1420070400000) / 1000) return round(((snowflake >> 22) + 1420070400000) / 1000)
async def add_check_reaction(message): async def add_check_reaction(message):
@ -80,7 +86,8 @@ async def add_check_reaction(message):
async def invalid_user_handler(interaction): async def invalid_user_handler(interaction):
await interaction.response.send_message( await interaction.response.send_message(
"you are not the intended receiver of this message!", ephemeral=True "you are not the intended receiver of this message!",
ephemeral=True,
) )