add basic typing
This commit is contained in:
parent
604a1378b7
commit
6b7949e1a1
1 changed files with 132 additions and 81 deletions
213
delojza.py
213
delojza.py
|
@ -16,7 +16,9 @@ from datetime import datetime, timedelta
|
|||
from glob import glob
|
||||
from operator import itemgetter
|
||||
from random import random
|
||||
from sqlite3.dbapi2 import Connection
|
||||
from time import sleep
|
||||
from typing import Any, List, Optional, Tuple, cast
|
||||
|
||||
import acoustid
|
||||
import filetype
|
||||
|
@ -26,8 +28,12 @@ import pytumblr
|
|||
import requests
|
||||
import telegram
|
||||
import youtube_dl
|
||||
from _typeshed import StrPath
|
||||
from markovify.text import Text
|
||||
from mutagen import File, FileType
|
||||
from mutagen.easyid3 import EasyID3
|
||||
from telegram.ext import Updater, CommandHandler, MessageHandler
|
||||
from requests.api import options
|
||||
from telegram.ext import CommandHandler, MessageHandler, Updater
|
||||
from youtube_dl import DownloadError
|
||||
from youtube_dl.version import __version__ as YTDL_VERSION
|
||||
|
||||
|
@ -42,33 +48,41 @@ def mkdir_p(path):
|
|||
raise
|
||||
|
||||
|
||||
def datestr(date):
|
||||
def datestr(date: datetime):
|
||||
return date.strftime("%Y-%m-%d@%H%M")
|
||||
|
||||
|
||||
class DelojzaDB:
|
||||
def __init__(self, db_path):
|
||||
self.db_path = db_path
|
||||
self.db = None
|
||||
self.db: Optional[Connection] = None
|
||||
|
||||
def initialize(self):
|
||||
if self.db is None:
|
||||
self.db = sqlite3.connect(self.db_path)
|
||||
|
||||
def get_protected_tags(self):
|
||||
if self.db is None:
|
||||
raise RuntimeError("Database not initialized!")
|
||||
results = self.db.execute("SELECT tag FROM tags WHERE protected == 1")
|
||||
return [res[0] for res in results.fetchall()]
|
||||
|
||||
def get_protected_chats(self):
|
||||
if self.db is None:
|
||||
raise RuntimeError("Database not initialized!")
|
||||
results = self.db.execute("SELECT id FROM chats WHERE protected == 1")
|
||||
return [res[0] for res in results.fetchall()]
|
||||
|
||||
def get_chat(self, id):
|
||||
if self.db is None:
|
||||
raise RuntimeError("Database not initialized!")
|
||||
return self.db.execute(
|
||||
"SELECT id, protected FROM chats WHERE id == ?", (id,)
|
||||
).fetchone()
|
||||
|
||||
def set_chat_protected(self, id, protected):
|
||||
def set_chat_protected(self, id: str, protected: bool):
|
||||
if self.db is None:
|
||||
raise RuntimeError("Database not initialized!")
|
||||
chat_in_db = self.get_chat(id)
|
||||
if chat_in_db:
|
||||
self.db.execute(
|
||||
|
@ -80,12 +94,16 @@ class DelojzaDB:
|
|||
)
|
||||
self.db.commit()
|
||||
|
||||
def get_tag(self, tag):
|
||||
def get_tag(self, tag: str):
|
||||
if self.db is None:
|
||||
raise RuntimeError("Database not initialized!")
|
||||
return self.db.execute(
|
||||
"SELECT id, tag, protected FROM tags WHERE tag == ?", (tag,)
|
||||
).fetchone()
|
||||
|
||||
def set_tag_protected(self, tag, protected):
|
||||
def set_tag_protected(self, tag: str, protected: bool):
|
||||
if self.db is None:
|
||||
raise RuntimeError("Database not initialized!")
|
||||
tag_in_db = self.get_tag(tag)
|
||||
if tag_in_db:
|
||||
self.db.execute(
|
||||
|
@ -98,19 +116,40 @@ class DelojzaDB:
|
|||
self.db.commit()
|
||||
|
||||
|
||||
class MarkovBlabberer:
|
||||
def __init__(self, filepath: StrPath):
|
||||
self.logger = logging.getLogger("markov")
|
||||
self.filepath = filepath
|
||||
|
||||
with open(filepath) as f:
|
||||
text = f.read()
|
||||
self.markov: Text = markovify.NewlineText(text.lower())
|
||||
self.logger.info("Sentence of the day: " + self.make_sentence())
|
||||
|
||||
def make_sentence(self, tries: int = 100):
|
||||
return self.markov.make_sentence(tries=tries) or "???"
|
||||
|
||||
def add_to_corpus(self, text: str):
|
||||
text = text.lower()
|
||||
new_sentence = markovify.NewlineText(text)
|
||||
self.markov = cast(Text, markovify.combine([self.markov, new_sentence]))
|
||||
with open(self.filepath, "a") as f:
|
||||
f.write(text + "\n")
|
||||
|
||||
|
||||
class DelojzaBot:
|
||||
def __init__(
|
||||
self,
|
||||
tg_api_key,
|
||||
out_dir,
|
||||
redirects=None,
|
||||
tmp_dir=None,
|
||||
db_path=None,
|
||||
protected_password=None,
|
||||
acoustid_key=None,
|
||||
tumblr_name=None,
|
||||
tumblr_keys=None,
|
||||
markov=None,
|
||||
tg_api_key: str,
|
||||
out_dir: StrPath,
|
||||
redirects: Optional[List[Tuple[str, str]]] = None,
|
||||
tmp_dir: Optional[StrPath] = None,
|
||||
db_path: Optional[StrPath] = None,
|
||||
protected_password: Optional[str] = None,
|
||||
acoustid_key: Optional[str] = None,
|
||||
tumblr_name: Optional[str] = None,
|
||||
tumblr_keys: Optional[Tuple[str, str, str, str]] = None,
|
||||
markov: Optional[MarkovBlabberer] = None,
|
||||
):
|
||||
self._setup_logging(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
|
@ -121,15 +160,16 @@ class DelojzaBot:
|
|||
|
||||
self.out_dir = os.path.abspath(out_dir)
|
||||
self.out_dir = self.out_dir[:-1] if self.out_dir[-1] == "/" else self.out_dir
|
||||
self.logger.debug("OUT_DIR: " + out_dir)
|
||||
self.logger.debug(f"OUT_DIR: {out_dir}")
|
||||
self.tmp_dir = tmp_dir if tmp_dir else tempfile.gettempdir()
|
||||
self.logger.debug("TMP_DIR: " + tmp_dir)
|
||||
self.logger.debug(f"TMP_DIR: {tmp_dir}")
|
||||
self.markov = markov
|
||||
|
||||
self.redirects = {}
|
||||
if redirects is not None:
|
||||
for hashtag, directory in redirects:
|
||||
hashtag = hashtag.upper()
|
||||
directory = str(directory)
|
||||
directory = directory[:-1] if directory[-1] == "/" else directory
|
||||
mkdir_p(directory)
|
||||
self.redirects[hashtag] = directory
|
||||
|
@ -162,14 +202,14 @@ class DelojzaBot:
|
|||
self.last_downloaded = {}
|
||||
self.last_hashtags = {}
|
||||
|
||||
def _setup_logging(self, log_path):
|
||||
def _setup_logging(self, log_path: StrPath):
|
||||
self.logger = logging.getLogger("delojza")
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(logging.INFO)
|
||||
|
||||
dfh = logging.FileHandler(log_path + "/delojza.log")
|
||||
dfh = logging.FileHandler(os.path.join(log_path, "delojza.log"))
|
||||
dfh.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
|
@ -190,7 +230,7 @@ class DelojzaBot:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def ytdl_can(url):
|
||||
def ytdl_can(url: str):
|
||||
ies = youtube_dl.extractor.gen_extractors()
|
||||
for ie in ies:
|
||||
if ie.suitable(url) and ie.IE_NAME != "generic" and "/channel/" not in url:
|
||||
|
@ -200,18 +240,18 @@ class DelojzaBot:
|
|||
|
||||
# https://github.com/django/django/blob/master/django/utils/text.py#L393
|
||||
@staticmethod
|
||||
def sanitize(filepath):
|
||||
if filepath is None:
|
||||
return None
|
||||
filepath = (
|
||||
unicodedata.normalize("NFKD", filepath)
|
||||
def sanitize(text: str):
|
||||
if text is None:
|
||||
return ""
|
||||
text = (
|
||||
unicodedata.normalize("NFKD", text)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii")
|
||||
)
|
||||
return re.sub(r"[^\w.()\[\]{}#-]", "_", filepath)
|
||||
return re.sub(r"[^\w.()\[\]{}#-]", "_", text)
|
||||
|
||||
@staticmethod
|
||||
def _get_tags(filepath):
|
||||
def _get_tags(filepath: StrPath):
|
||||
try:
|
||||
audio = EasyID3(filepath)
|
||||
return (
|
||||
|
@ -222,11 +262,13 @@ class DelojzaBot:
|
|||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def _tag_file(filepath, artist, title):
|
||||
def _tag_file(filepath: StrPath, artist: Optional[str], title: str):
|
||||
try:
|
||||
id3 = mutagen.id3.ID3(filepath)
|
||||
except mutagen.id3.ID3NoHeaderError:
|
||||
mutafile = mutagen.File(filepath)
|
||||
mutafile = cast(Optional[FileType], File(filepath))
|
||||
if not mutafile:
|
||||
return
|
||||
mutafile.add_tags()
|
||||
mutafile.save()
|
||||
id3 = mutagen.id3.ID3(filepath)
|
||||
|
@ -258,7 +300,7 @@ class DelojzaBot:
|
|||
score, rid, aid_title, aid_artist = results[0]
|
||||
if score > 0.4:
|
||||
title = aid_title
|
||||
artist = re.sub(r" *; +", " & ", aid_artist)
|
||||
artist = re.sub(r" *; +", " & ", aid_artist or "")
|
||||
best_acoustid_score = score
|
||||
source = "AcoustID ({}%)".format(round(score * 100))
|
||||
except acoustid.NoBackendError:
|
||||
|
@ -296,7 +338,7 @@ class DelojzaBot:
|
|||
artist = artist.strip() if artist else None
|
||||
title = title.strip() if title else None
|
||||
|
||||
if title is None and artist is None:
|
||||
if title is None:
|
||||
message.reply_text("Tried tagging, found nothing :(")
|
||||
return
|
||||
|
||||
|
@ -309,7 +351,7 @@ class DelojzaBot:
|
|||
self._tag_file(filepath, artist, title)
|
||||
|
||||
@staticmethod
|
||||
def _get_percent_filled(directory):
|
||||
def _get_percent_filled(directory: str):
|
||||
output = subprocess.check_output(["df", directory])
|
||||
percents_re = re.search(r"[0-9]+%", output.decode("utf-8"))
|
||||
if not percents_re:
|
||||
|
@ -317,7 +359,15 @@ class DelojzaBot:
|
|||
return int(percents_re.group(0)[:-1])
|
||||
|
||||
# noinspection PyUnusedLocal
|
||||
def download_ytdl(self, urls, out_path, date, message, audio=False, filetitle=None):
|
||||
def download_ytdl(
|
||||
self,
|
||||
urls: List[str],
|
||||
out_path: StrPath,
|
||||
date: datetime,
|
||||
message: telegram.Message,
|
||||
audio: bool = False,
|
||||
filetitle: Optional[str] = None,
|
||||
):
|
||||
ytdl = {
|
||||
"noplaylist": True,
|
||||
"restrictfilenames": True,
|
||||
|
@ -352,7 +402,7 @@ class DelojzaBot:
|
|||
else:
|
||||
raise exc
|
||||
for info in [ytdl.extract_info(url, download=False) for url in urls]:
|
||||
filename = ytdl.prepare_filename(info)
|
||||
filename = cast(str, ytdl.prepare_filename(info))
|
||||
globbeds = glob(os.path.splitext(filename)[0] + ".*")
|
||||
for globbed in globbeds:
|
||||
if globbed.endswith("mp3"):
|
||||
|
@ -362,7 +412,15 @@ class DelojzaBot:
|
|||
filenames.append(dest)
|
||||
return filenames
|
||||
|
||||
def download_raw(self, urls, out_path, date, message, audio=False, filetitle=None):
|
||||
def download_raw(
|
||||
self,
|
||||
urls: List[str],
|
||||
out_path: StrPath,
|
||||
date: datetime,
|
||||
message: telegram.Message,
|
||||
audio: bool = False,
|
||||
filetitle: Optional[str] = None,
|
||||
):
|
||||
filenames = []
|
||||
for url in urls:
|
||||
local_filename = os.path.join(
|
||||
|
@ -409,7 +467,7 @@ class DelojzaBot:
|
|||
return filenames
|
||||
|
||||
@staticmethod
|
||||
def extract_hashtags(message):
|
||||
def extract_hashtags(message: telegram.Message):
|
||||
hashtags = list(
|
||||
map(
|
||||
message.parse_entity,
|
||||
|
@ -429,7 +487,7 @@ class DelojzaBot:
|
|||
hashtags[i] = "PRAS"
|
||||
return hashtags
|
||||
|
||||
def _get_hashtags(self, message):
|
||||
def _get_hashtags(self, message: telegram.Message):
|
||||
hashtags = self.extract_hashtags(message)
|
||||
if len(hashtags) == 0 and self.last_hashtags.get(message.chat.id) is not None:
|
||||
user, ts, last_hashtags = self.last_hashtags[message.chat.id]
|
||||
|
@ -437,7 +495,7 @@ class DelojzaBot:
|
|||
hashtags = last_hashtags
|
||||
return hashtags
|
||||
|
||||
def handle_text(self, message, hashtags):
|
||||
def handle_text(self, message: telegram.Message, hashtags: List[str]):
|
||||
if len(hashtags) == 0 or hashtags[0] not in ("TEXT", "TXT"):
|
||||
return
|
||||
|
||||
|
@ -464,7 +522,14 @@ class DelojzaBot:
|
|||
)
|
||||
|
||||
# noinspection PyBroadException
|
||||
def handle(self, urls, message, hashtags, download_fn, filetitle=None):
|
||||
def handle(
|
||||
self,
|
||||
urls: List[str],
|
||||
message: telegram.Message,
|
||||
hashtags: List[str],
|
||||
download_fn: Any,
|
||||
filetitle=None,
|
||||
):
|
||||
self.db.initialize()
|
||||
|
||||
try:
|
||||
|
@ -605,7 +670,7 @@ class DelojzaBot:
|
|||
else:
|
||||
return False
|
||||
|
||||
def handle_urls(self, message, hashtags):
|
||||
def handle_urls(self, message: telegram.Message, hashtags: List[str]):
|
||||
urls = list(
|
||||
map(
|
||||
lambda e: message.parse_entity(e),
|
||||
|
@ -631,7 +696,7 @@ class DelojzaBot:
|
|||
|
||||
return ytdl_res or raw_res
|
||||
|
||||
def tg_handle(self, bot, update):
|
||||
def tg_handle(self, bot: telegram.Bot, update: telegram.Update):
|
||||
self._log_msg(update)
|
||||
hashtags = self._get_hashtags(update.message)
|
||||
if hashtags:
|
||||
|
@ -682,7 +747,7 @@ class DelojzaBot:
|
|||
+ list(self.redirects.keys())
|
||||
)
|
||||
|
||||
def tg_stats(self, _, update):
|
||||
def tg_stats(self, _, update: telegram.Update):
|
||||
self._log_msg(update)
|
||||
self.db.initialize()
|
||||
if update.message.chat.id not in self.db.get_protected_chats():
|
||||
|
@ -750,7 +815,7 @@ class DelojzaBot:
|
|||
result.append((directory, "NO FILE AT ALL..."))
|
||||
return sorted(result, key=itemgetter(0))
|
||||
|
||||
def tg_orphan(self, _, update):
|
||||
def tg_orphan(self, _, update: telegram.Update):
|
||||
self._log_msg(update)
|
||||
self.db.initialize()
|
||||
if update.message.chat.id not in self.db.get_protected_chats():
|
||||
|
@ -793,7 +858,7 @@ class DelojzaBot:
|
|||
if len(tmp_reply) > 0:
|
||||
update.message.reply_text(tmp_reply)
|
||||
|
||||
def tg_retag(self, _, update):
|
||||
def tg_retag(self, _, update: telegram.Update):
|
||||
self._log_msg(update)
|
||||
if self.last_downloaded.get(update.message.chat.id) is not None:
|
||||
files, hashtags, tumblr_ids = self.last_downloaded[update.message.chat.id]
|
||||
|
@ -817,7 +882,7 @@ class DelojzaBot:
|
|||
orig_artist, orig_title = self._get_tags(mp3)
|
||||
title, artist = orig_artist, orig_title
|
||||
|
||||
self._tag_file(mp3, artist, title)
|
||||
self._tag_file(mp3, artist, cast(str, title))
|
||||
update.message.reply_text(
|
||||
'Tagging "{}" as "{}" by "{}"!'.format(
|
||||
mp3[len(out_dir) + 1 :], title, artist
|
||||
|
@ -829,7 +894,7 @@ class DelojzaBot:
|
|||
+ "???"
|
||||
)
|
||||
|
||||
def tg_delete(self, _, update):
|
||||
def tg_delete(self, _, update: telegram.Update):
|
||||
self._log_msg(update)
|
||||
if self.last_downloaded.get(update.message.chat.id) is not None:
|
||||
files, hashtags, tumblr_ids = self.last_downloaded[update.message.chat.id]
|
||||
|
@ -865,7 +930,7 @@ class DelojzaBot:
|
|||
return
|
||||
update.message.reply_text("Nothing to remove!")
|
||||
|
||||
def tg_protect(self, _, update):
|
||||
def tg_protect(self, _, update: telegram.Update):
|
||||
self._log_msg(update)
|
||||
self.db.initialize()
|
||||
|
||||
|
@ -929,7 +994,7 @@ class DelojzaBot:
|
|||
+ "???"
|
||||
)
|
||||
|
||||
def tg_queue(self, _, update):
|
||||
def tg_queue(self, _, update: telegram.Update):
|
||||
if self.tumblr_client:
|
||||
blog_info = self.tumblr_client.blog_info(self.tumblr_name)
|
||||
update.message.reply_text(
|
||||
|
@ -943,7 +1008,7 @@ class DelojzaBot:
|
|||
)
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
def tg_version(self, _, update):
|
||||
def tg_version(self, _, update: telegram.Update):
|
||||
self._log_msg(update)
|
||||
delojza_date = datetime.fromtimestamp(
|
||||
os.path.getmtime(os.path.realpath(__file__))
|
||||
|
@ -954,13 +1019,15 @@ class DelojzaBot:
|
|||
)
|
||||
)
|
||||
|
||||
def tg_start(self, _, update):
|
||||
def tg_start(self, _, update: telegram.Update):
|
||||
self._log_msg(update)
|
||||
update.message.reply_text(
|
||||
self.markov.make_sentence() if self.markov else "HELLO"
|
||||
)
|
||||
|
||||
def tg_error(self, bot, update, error):
|
||||
def tg_error(
|
||||
self, bot: telegram.Bot, update: telegram.Update, error: telegram.TelegramError
|
||||
):
|
||||
self.logger.error(error)
|
||||
if "Timed out" in str(error):
|
||||
if update is not None:
|
||||
|
@ -980,27 +1047,6 @@ class DelojzaBot:
|
|||
self.updater.idle()
|
||||
|
||||
|
||||
class MarkovBlabberer:
|
||||
def __init__(self, filepath):
|
||||
self.logger = logging.getLogger("markov")
|
||||
self.filepath = filepath
|
||||
|
||||
with open(filepath) as f:
|
||||
text = f.read()
|
||||
self.markov = markovify.NewlineText(text.lower())
|
||||
self.logger.info("Sentence of the day: " + self.make_sentence())
|
||||
|
||||
def make_sentence(self, tries=100):
|
||||
return self.markov.make_sentence(tries=tries) or "???"
|
||||
|
||||
def add_to_corpus(self, text):
|
||||
text = text.lower()
|
||||
new_sentence = markovify.NewlineText(text)
|
||||
self.markov = markovify.combine([self.markov, new_sentence])
|
||||
with open(self.filepath, "a") as f:
|
||||
f.write(text + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
@ -1033,9 +1079,19 @@ if __name__ == "__main__":
|
|||
markov = None
|
||||
|
||||
try:
|
||||
redirects = config.items("redirects")
|
||||
redirects: Optional[List[Tuple[str, str]]] = config.items("redirects")
|
||||
except NoSectionError:
|
||||
redirects = {}
|
||||
redirects = None
|
||||
|
||||
try:
|
||||
tumblr_keys = (
|
||||
config.get("tumblr", "consumer_key"),
|
||||
config.get("tumblr", "consumer_secret"),
|
||||
config.get("tumblr", "oauth_key"),
|
||||
config.get("tumblr", "oauth_secret"),
|
||||
)
|
||||
except (NoSectionError, KeyError):
|
||||
tumblr_keys = None
|
||||
|
||||
delojza = DelojzaBot(
|
||||
config.get("delojza", "tg_api_key"),
|
||||
|
@ -1045,12 +1101,7 @@ if __name__ == "__main__":
|
|||
protected_password=config.get("delojza", "protected_password", fallback=None),
|
||||
acoustid_key=config.get("delojza", "acoustid_api_key", fallback=None),
|
||||
tumblr_name=config.get("tumblr", "blog_name", fallback=None),
|
||||
tumblr_keys=(
|
||||
config.get("tumblr", "consumer_key", fallback=None),
|
||||
config.get("tumblr", "consumer_secret", fallback=None),
|
||||
config.get("tumblr", "oauth_key", fallback=None),
|
||||
config.get("tumblr", "oauth_secret", fallback=None),
|
||||
),
|
||||
tumblr_keys=tumblr_keys,
|
||||
markov=markov,
|
||||
)
|
||||
delojza.run_idle()
|
||||
|
|
Loading…
Add table
Reference in a new issue