refactor: rewrite reload system

This commit is contained in:
Ryan 2024-12-30 14:52:51 -05:00
parent 8985999d6c
commit 6d7b46a7e5
Signed by: ErrorNoInternet
GPG Key ID: 2486BFB7B1E6A4A3
9 changed files with 37 additions and 105 deletions

View File

@ -1,21 +1,6 @@
import importlib
import inspect
from state import reloaded_modules
from . import bot, tools, utils, voice from . import bot, tools, utils, voice
from .utils import * from .utils import *
def __reload_module__(): def __reload_module__():
for name, module in globals().items():
if (
inspect.ismodule(module)
and name not in constants.RELOAD_BLACKLISTED_MODULES
):
importlib.reload(module)
if "__reload_module__" in dir(module) and name not in reloaded_modules:
reloaded_modules.add(name)
module.__reload_module__()
globals().update({k: v for k, v in vars(utils).items() if not k.startswith("_")}) globals().update({k: v for k, v in vars(utils).items() if not k.startswith("_")})

View File

@ -1,12 +1,9 @@
import importlib
import inspect
import time import time
import arguments import arguments
import commands import commands
import constants
import utils import utils
from state import reloaded_modules, start_time from state import start_time
async def uptime(message): async def uptime(message):
@ -28,15 +25,3 @@ async def uptime(message):
await utils.reply(message, f"{round(start_time)}") await utils.reply(message, f"{round(start_time)}")
else: else:
await utils.reply(message, f"up {round(time.time() - start_time)} seconds") await utils.reply(message, f"up {round(time.time() - start_time)} seconds")
def __reload_module__():
for name, module in globals().items():
if (
inspect.ismodule(module)
and name not in constants.RELOAD_BLACKLISTED_MODULES
):
importlib.reload(module)
if "__reload_module__" in dir(module) and name not in reloaded_modules:
reloaded_modules.add(name)
module.__reload_module__()

View File

@ -1,11 +1,8 @@
import importlib
import inspect
import re import re
import arguments import arguments
import commands import commands
import constants import utils
from state import reloaded_modules
async def clear(message): async def clear(message):
@ -63,15 +60,3 @@ async def clear(message):
) )
except: except:
pass pass
def __reload_module__():
for name, module in globals().items():
if (
inspect.ismodule(module)
and name not in constants.RELOAD_BLACKLISTED_MODULES
):
importlib.reload(module)
if "__reload_module__" in dir(module) and name not in reloaded_modules:
reloaded_modules.add(name)
module.__reload_module__()

View File

@ -1,12 +1,8 @@
import importlib
import inspect
import arguments import arguments
import commands import commands
import constants
import utils import utils
import ytdlp import ytdlp
from state import client, playback_queue, reloaded_modules from state import client, playback_queue
async def queue_or_play(message): async def queue_or_play(message):
@ -225,15 +221,3 @@ def generate_queue_list(queue: list):
for i, queued in enumerate(queue) for i, queued in enumerate(queue)
] ]
) )
def __reload_module__():
for name, module in globals().items():
if (
inspect.ismodule(module)
and name not in constants.RELOAD_BLACKLISTED_MODULES
):
importlib.reload(module)
if "__reload_module__" in dir(module) and name not in reloaded_modules:
reloaded_modules.add(name)
module.__reload_module__()

View File

@ -1,9 +1,10 @@
import os import os
import sys
EMBED_COLOR = 0xFF6600 EMBED_COLOR = 0xFF6600
OWNERS = [531392146767347712] OWNERS = [531392146767347712]
PREFIX = "%" PREFIX = "%"
RELOAD_BLACKLISTED_MODULES = ["re", "argparse"] RELOAD_BLACKLISTED_MODULES = [*sys.builtin_module_names]
YTDL_OPTIONS = { YTDL_OPTIONS = {
"default_search": "auto", "default_search": "auto",

View File

@ -1,6 +1,4 @@
import contextlib import contextlib
import importlib
import inspect
import io import io
import textwrap import textwrap
import traceback import traceback
@ -10,7 +8,6 @@ import disnake_paginator
import commands import commands
import constants import constants
import utils import utils
from state import reloaded_modules
async def on_message(message): async def on_message(message):
@ -99,15 +96,3 @@ async def on_message(message):
message, message,
f"exception occurred while processing command: ```\n{''.join(traceback.format_exception(e)).replace('`', '\\`')}```", f"exception occurred while processing command: ```\n{''.join(traceback.format_exception(e)).replace('`', '\\`')}```",
) )
def __reload_module__():
for name, module in globals().items():
if (
inspect.ismodule(module)
and name not in constants.RELOAD_BLACKLISTED_MODULES
):
importlib.reload(module)
if "__reload_module__" in dir(module) and name not in reloaded_modules:
reloaded_modules.add(name)
module.__reload_module__()

43
main.py
View File

@ -1,3 +1,4 @@
import contextlib
import importlib import importlib
import inspect import inspect
import time import time
@ -6,7 +7,7 @@ import commands
import constants import constants
import core import core
import events import events
from state import client, reloaded_modules, start_time from state import client, start_time
@client.event @client.event
@ -33,20 +34,40 @@ async def on_message(message):
if message.author.id in constants.OWNERS and commands.match(message.content) == [ if message.author.id in constants.OWNERS and commands.match(message.content) == [
commands.Command.RELOAD commands.Command.RELOAD
]: ]:
for name, module in globals().items(): reloaded_modules = set()
if ( for module in filter(
inspect.ismodule(module) lambda v: inspect.ismodule(v)
and name not in constants.RELOAD_BLACKLISTED_MODULES and v.__name__ not in constants.RELOAD_BLACKLISTED_MODULES,
): globals().values(),
importlib.reload(module) ):
if "__reload_module__" in dir(module) and name not in reloaded_modules: rreload(reloaded_modules, module)
reloaded_modules.add(name)
module.__reload_module__()
reloaded_modules.clear()
await message.add_reaction("") await message.add_reaction("")
return return
await events.on_message(message) await events.on_message(message)
def rreload(reloaded_modules, module):
reloaded_modules.add(module)
importlib.reload(module)
if "__reload_module__" in dir(module):
module.__reload_module__()
with contextlib.suppress(AttributeError):
for module in filter(
lambda m: m.__spec__.origin != "frozen",
filter(
lambda v: inspect.ismodule(v)
and (
v.__name__.split(".")[-1]
not in constants.RELOAD_BLACKLISTED_MODULES
)
and (v not in reloaded_modules),
map(lambda attr: getattr(module, attr), dir(module)),
),
):
rreload(reloaded_modules, module)
client.run(constants.SECRETS["TOKEN"]) client.run(constants.SECRETS["TOKEN"])

View File

@ -5,7 +5,6 @@ import disnake
start_time = time.time() start_time = time.time()
playback_queue = {} playback_queue = {}
reloaded_modules = set()
intents = disnake.Intents.default() intents = disnake.Intents.default()
intents.message_content = True intents.message_content = True

View File

@ -1,13 +1,10 @@
import asyncio import asyncio
import importlib
import inspect
from typing import Any, Optional from typing import Any, Optional
import disnake import disnake
import yt_dlp import yt_dlp
import constants import constants
from state import reloaded_modules
ytdl = yt_dlp.YoutubeDL(constants.YTDL_OPTIONS) ytdl = yt_dlp.YoutubeDL(constants.YTDL_OPTIONS)
@ -45,15 +42,5 @@ class YTDLSource(disnake.PCMVolumeTransformer):
def __reload_module__(): def __reload_module__():
for name, module in globals().items():
if (
inspect.ismodule(module)
and name not in constants.RELOAD_BLACKLISTED_MODULES
):
importlib.reload(module)
if "__reload_module__" in dir(module) and name not in reloaded_modules:
reloaded_modules.add(name)
module.__reload_module__()
global ytdl global ytdl
ytdl = yt_dlp.YoutubeDL(constants.YTDL_OPTIONS) ytdl = yt_dlp.YoutubeDL(constants.YTDL_OPTIONS)