refactor: follow more guidelines

This commit is contained in:
Ryan 2025-04-03 17:32:44 -04:00
parent f360566824
commit 7f62b0f273
Signed by: ErrorNoInternet
GPG Key ID: 2486BFB7B1E6A4A3
15 changed files with 100 additions and 71 deletions

View File

@ -8,7 +8,9 @@ import utils
class ArgumentParser:
def __init__(self, command, description):
self.parser = argparse.ArgumentParser(
command, description=description, exit_on_error=False
command,
description=description,
exit_on_error=False,
)
def print_help(self):
@ -26,21 +28,20 @@ class ArgumentParser:
async def parse_args(self, message, tokens) -> argparse.Namespace | None:
try:
with contextlib.redirect_stdout(io.StringIO()):
args = self.parser.parse_args(tokens[1:])
return args
return self.parser.parse_args(tokens[1:])
except SystemExit:
await utils.reply(message, f"```\n{self.print_help()}```")
except Exception as e:
await utils.reply(message, f"`{e}`")
def range_type(string: str, min=0, max=100):
def range_type(string: str, lower=0, upper=100):
try:
value = int(string)
except ValueError:
raise argparse.ArgumentTypeError("value is not a valid integer")
if min <= value <= max:
if lower <= value <= upper:
return value
else:
raise argparse.ArgumentTypeError(f"value is not in range {min}-{max}")
raise argparse.ArgumentTypeError(f"value is not in range {lower}-{upper}")

View File

@ -1,6 +1,6 @@
import collections
from dataclasses import dataclass
from typing import Optional
from typing import ClassVar, Optional
import disnake
@ -16,20 +16,19 @@ class Song:
trigger_message: disnake.Message
def format(self, show_queuer=False, hide_preview=False, multiline=False) -> str:
title = f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''})"
duration = (
format_duration(self.player.duration) if self.player.duration else "stream"
)
if multiline:
return (
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''})\n**duration:** {format_duration(self.player.duration) if self.player.duration else '[stream]'}"
+ (
f", **queued by:** <@{self.trigger_message.author.id}>"
if show_queuer
else ""
)
)
else:
return (
f"[`{self.player.title}`]({'<' if hide_preview else ''}{self.player.original_url}{'>' if hide_preview else ''}) [**{format_duration(self.player.duration) if self.player.duration else 'stream'}**]"
+ (f" (<@{self.trigger_message.author.id}>)" if show_queuer else "")
return f"{title}\n**duration:** {duration}" + (
f", **queued by:** <@{self.trigger_message.author.id}>"
if show_queuer
else ""
)
return f"{title} [**{duration}**]" + (
f" (<@{self.trigger_message.author.id}>)" if show_queuer else ""
)
def embed(self, is_paused=False):
progress = 0
@ -90,7 +89,7 @@ class Song:
@dataclass
class Player:
queue = collections.deque()
queue: ClassVar = collections.deque()
current: Optional[Song] = None
def queue_pop(self):

View File

@ -13,7 +13,11 @@ ytdl = yt_dlp.YoutubeDL(YTDL_OPTIONS)
class YTDLSource(PCMVolumeTransformer):
def __init__(
self, source: TrackedAudioSource, *, data: dict[str, Any], volume: float = 0.5
self,
source: TrackedAudioSource,
*,
data: dict[str, Any],
volume: float = 0.5,
):
super().__init__(source, volume)
@ -39,7 +43,8 @@ class YTDLSource(PCMVolumeTransformer):
):
loop = loop or asyncio.get_event_loop()
data: Any = await loop.run_in_executor(
None, lambda: ytdl.extract_info(url, download=not stream)
None,
lambda: ytdl.extract_info(url, download=not stream),
)
if "entries" in data:
@ -54,7 +59,7 @@ class YTDLSource(PCMVolumeTransformer):
disnake.FFmpegPCMAudio(
data["url"] if stream else ytdl.prepare_filename(data),
before_options="-vn -reconnect 1",
)
),
),
data=data,
)

View File

@ -98,6 +98,6 @@ async def help(message):
await reply(
message,
", ".join(
[f"`{command.value}`" for command in commands.Command.__members__.values()]
[f"`{command.value}`" for command in commands.Command.__members__.values()],
),
)

View File

@ -41,7 +41,7 @@ async def lookup(message):
embed = disnake.Embed(description=response["description"], color=EMBED_COLOR)
embed.set_thumbnail(
url=f"https://cdn.discordapp.com/app-icons/{response['id']}/{response['icon']}.webp"
url=f"https://cdn.discordapp.com/app-icons/{response['id']}/{response['icon']}.webp",
)
embed.add_field(name="Application Name", value=response["name"])
embed.add_field(name="Application ID", value="`" + response["id"] + "`")
@ -102,7 +102,9 @@ async def lookup(message):
for tag in response["tags"]:
bot_tags += tag + ", "
embed.add_field(
name="Tags", value="None" if bot_tags == "" else bot_tags[:-2], inline=False
name="Tags",
value="None" if bot_tags == "" else bot_tags[:-2],
inline=False,
)
else:
try:
@ -165,7 +167,7 @@ async def clear(message):
)
parser.add_argument(
"count",
type=lambda c: arguments.range_type(c, min=1, max=1000),
type=lambda c: arguments.range_type(c, lower=1, upper=1000),
help="amount of messages to delete",
)
group = parser.add_mutually_exclusive_group()
@ -259,8 +261,10 @@ async def clear(message):
messages = len(
await message.channel.purge(
limit=args.count, check=check, oldest_first=args.oldest_first
)
limit=args.count,
check=check,
oldest_first=args.oldest_first,
),
)
if not args.delete_command:

View File

@ -42,7 +42,7 @@ def match_token(token: str) -> list[Command]:
filter(
lambda command: command.value == token.lower(),
Command.__members__.values(),
)
),
):
return exact_match
@ -50,7 +50,7 @@ def match_token(token: str) -> list[Command]:
filter(
lambda command: command.value.startswith(token.lower()),
Command.__members__.values(),
)
),
)

View File

@ -13,7 +13,8 @@ from .utils import command_allowed
async def playing(message):
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(
tokens[0], "get information about the currently playing song"
tokens[0],
"get information about the currently playing song",
)
parser.add_argument(
"-d",
@ -49,7 +50,7 @@ async def playing(message):
await utils.reply(
message,
embed=players[message.guild.id].current.embed(
is_paused=message.guild.voice_client.is_paused()
is_paused=message.guild.voice_client.is_paused(),
),
)
else:
@ -94,7 +95,7 @@ async def fast_forward(message):
"-s",
"--seconds",
nargs="?",
type=lambda v: arguments.range_type(v, min=0, max=300),
type=lambda v: arguments.range_type(v, lower=0, upper=300),
help="the amount of seconds to fast forward instead",
)
if not (args := await parser.parse_args(message, tokens)):
@ -110,11 +111,12 @@ async def fast_forward(message):
seconds = args.seconds
if not seconds:
video = await sponsorblock.get_segments(
players[message.guild.id].current.player.id
players[message.guild.id].current.player.id,
)
if not video:
await utils.reply(
message, "no sponsorblock segments were found for this video!"
message,
"no sponsorblock segments were found for this video!",
)
return
@ -140,7 +142,7 @@ async def volume(message):
parser.add_argument(
"volume",
nargs="?",
type=lambda v: arguments.range_type(v, min=0, max=150),
type=lambda v: arguments.range_type(v, lower=0, upper=150),
help="the volume level (0 - 150)",
)
if not (args := await parser.parse_args(message, tokens)):

View File

@ -20,14 +20,15 @@ async def queue_or_play(message, edited=False):
tokens = commands.tokenize(message.content)
parser = arguments.ArgumentParser(
tokens[0], "queue a song, list the queue, or resume playback"
tokens[0],
"queue a song, list the queue, or resume playback",
)
parser.add_argument("query", nargs="*", help="yt-dlp URL or query to get song")
parser.add_argument(
"-v",
"--volume",
default=50,
type=lambda v: arguments.range_type(v, min=0, max=150),
type=lambda v: arguments.range_type(v, lower=0, upper=150),
help="the volume level (0 - 150) for the specified song",
)
parser.add_argument(
@ -140,8 +141,8 @@ async def queue_or_play(message, edited=False):
lambda queued: queued.trigger_message.author.id
== message.author.id,
players[message.guild.id].queue,
)
)
),
),
)
>= 5
and not len(message.guild.voice_client.channel.members) == 2
@ -155,7 +156,9 @@ async def queue_or_play(message, edited=False):
try:
async with message.channel.typing():
player = await audio.youtubedl.YTDLSource.from_url(
" ".join(query), loop=client.loop, stream=True
" ".join(query),
loop=client.loop,
stream=True,
)
player.volume = float(args.volume) / 100.0
except Exception as e:
@ -192,7 +195,7 @@ async def queue_or_play(message, edited=False):
[
queued.player.duration if queued.player.duration else 0
for queued in players[message.guild.id].queue
]
],
),
natural=True,
)
@ -217,13 +220,14 @@ async def queue_or_play(message, edited=False):
[
f"**{i + 1}.** {queued.format(show_queuer=True, hide_preview=True, multiline=True)}"
for i, queued in batch
]
],
)
for batch in itertools.batched(
enumerate(players[message.guild.id].queue), 10
enumerate(players[message.guild.id].queue),
10,
)
],
)
),
),
).start(utils.MessageInteractionWrapper(message))
else:

View File

@ -21,7 +21,8 @@ async def sponsorblock_command(message):
video = await sponsorblock.get_segments(players[message.guild.id].current.player.id)
if not video:
await utils.reply(
message, "no sponsorblock segments were found for this video!"
message,
"no sponsorblock segments were found for this video!",
)
return
@ -33,7 +34,7 @@ async def sponsorblock_command(message):
current = "**" if progress >= begin and progress < end else ""
text.append(
f"{current}`{audio.utils.format_duration(begin)}` - `{audio.utils.format_duration(end)}`: {category}{current}"
f"{current}`{audio.utils.format_duration(begin)}` - `{audio.utils.format_duration(end)}`: {category}{current}",
)
await utils.reply(

View File

@ -24,7 +24,8 @@ def play_next(message, once=False, first=False):
if message.guild.id in players and players[message.guild.id].queue:
queued = players[message.guild.id].queue_pop()
message.guild.voice_client.play(
queued.player, after=lambda e: play_after_callback(e, message, once)
queued.player,
after=lambda e: play_after_callback(e, message, once),
)
embed = queued.embed()

View File

@ -72,7 +72,7 @@ async def on_message(message, edited=False):
end = time.time()
debug(
f"reloaded {len(reloaded_modules)} modules in {round(end - start, 2)}s"
f"reloaded {len(reloaded_modules)} modules in {round(end - start, 2)}s",
)
await utils.add_check_reaction(message)

View File

@ -8,11 +8,15 @@ from state import client, kill, players
async def transcript(
message, languages=["en"], max_messages=6, min_messages=3, upper=True
message,
languages=["en"],
max_messages=6,
min_messages=3,
upper=True,
):
initial_id = message.guild.voice_client.source.id
transcript_list = youtube_transcript_api.YouTubeTranscriptApi.list_transcripts(
initial_id
initial_id,
)
try:
transcript = transcript_list.find_manually_created_transcript(languages).fetch()
@ -44,7 +48,7 @@ async def transcript(
await messages.pop().delete()
else:
await message.channel.delete_messages(
[messages.pop() for _ in range(count)]
[messages.pop() for _ in range(count)],
)
except Exception:
pass
@ -77,19 +81,20 @@ def messages_per_second(limit=500):
average = 1
print(
f"I am receiving **{average} {'message' if average == 1 else 'messages'} per second** "
f"from **{len(members)} {'member' if len(members) == 1 else 'members'}** across **{len(guilds)} {'guild' if len(guilds) == 1 else 'guilds'}**"
f"from **{len(members)} {'member' if len(members) == 1 else 'members'}** across **{len(guilds)} {'guild' if len(guilds) == 1 else 'guilds'}**",
)
async def auto_count(channel_id: int):
if (channel := await client.fetch_channel(channel_id)) and isinstance(
channel, disnake.TextChannel
channel,
disnake.TextChannel,
):
last_message = (await channel.history(limit=1).flatten())[0]
try:
result = str(
int("".join(filter(lambda d: d in string.digits, last_message.content)))
+ 1
+ 1,
)
except Exception:
result = "where number"

View File

@ -14,7 +14,7 @@ categories = json.dumps(
"preview",
"selfpromo",
"sponsor",
]
],
)
@ -30,7 +30,7 @@ async def get_segments(video_id: str):
)
if response.status == 200 and (
results := list(
filter(lambda v: video_id == v["videoID"], await response.json())
filter(lambda v: video_id == v["videoID"], await response.json()),
)
):
sponsorblock_cache[video_id] = results[0]

View File

@ -7,15 +7,15 @@ class TestFilterSecrets(unittest.TestCase):
def test_filter_secrets(self):
secret = "PLACEHOLDER_TOKEN"
self.assertFalse(
secret in utils.filter_secrets(f"HELLO{secret}WORLD", {"TOKEN": secret})
secret in utils.filter_secrets(f"HELLO{secret}WORLD", {"TOKEN": secret}),
)
self.assertFalse(secret in utils.filter_secrets(secret, {"TOKEN": secret}))
self.assertFalse(
secret in utils.filter_secrets(f"123{secret}", {"TOKEN": secret})
secret in utils.filter_secrets(f"123{secret}", {"TOKEN": secret}),
)
self.assertFalse(
secret in utils.filter_secrets(f"{secret}{secret}", {"TOKEN": secret})
secret in utils.filter_secrets(f"{secret}{secret}", {"TOKEN": secret}),
)
self.assertFalse(
secret in utils.filter_secrets(f"{secret}@#(*&*$)", {"TOKEN": secret})
secret in utils.filter_secrets(f"{secret}@#(*&*$)", {"TOKEN": secret}),
)

View File

@ -1,6 +1,6 @@
import os
import time
from logging import error, info
from pathlib import Path
import disnake
@ -34,7 +34,9 @@ async def reply(message, *args, **kwargs):
try:
await message_responses[message.id].edit(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
*args,
**kwargs,
allowed_mentions=disnake.AllowedMentions.none(),
)
return
except Exception:
@ -42,7 +44,9 @@ async def reply(message, *args, **kwargs):
try:
response = await message.reply(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
*args,
**kwargs,
allowed_mentions=disnake.AllowedMentions.none(),
)
except Exception:
response = await channel_send(message, *args, **kwargs)
@ -52,13 +56,15 @@ async def reply(message, *args, **kwargs):
async def channel_send(message, *args, **kwargs):
await message.channel.send(
*args, **kwargs, allowed_mentions=disnake.AllowedMentions.none()
*args,
**kwargs,
allowed_mentions=disnake.AllowedMentions.none(),
)
def load_opus():
for path in filter(
lambda p: os.path.exists(p),
lambda p: Path(p).exists(),
["/usr/lib64/libopus.so.0", "/usr/lib/libopus.so.0"],
):
try:
@ -70,8 +76,8 @@ def load_opus():
raise Exception("could not locate working opus library")
def snowflake_timestamp(id):
return round(((id >> 22) + 1420070400000) / 1000)
def snowflake_timestamp(snowflake):
return round(((snowflake >> 22) + 1420070400000) / 1000)
async def add_check_reaction(message):
@ -80,7 +86,8 @@ async def add_check_reaction(message):
async def invalid_user_handler(interaction):
await interaction.response.send_message(
"you are not the intended receiver of this message!", ephemeral=True
"you are not the intended receiver of this message!",
ephemeral=True,
)