add basic typing

This commit is contained in:
Tomáš Mládek 2021-09-14 16:03:55 +02:00
parent 604a1378b7
commit 6b7949e1a1

View file

@ -16,7 +16,9 @@ from datetime import datetime, timedelta
from glob import glob from glob import glob
from operator import itemgetter from operator import itemgetter
from random import random from random import random
from sqlite3.dbapi2 import Connection
from time import sleep from time import sleep
from typing import Any, List, Optional, Tuple, cast
import acoustid import acoustid
import filetype import filetype
@ -26,8 +28,12 @@ import pytumblr
import requests import requests
import telegram import telegram
import youtube_dl import youtube_dl
from _typeshed import StrPath
from markovify.text import Text
from mutagen import File, FileType
from mutagen.easyid3 import EasyID3 from mutagen.easyid3 import EasyID3
from telegram.ext import Updater, CommandHandler, MessageHandler from requests.api import options
from telegram.ext import CommandHandler, MessageHandler, Updater
from youtube_dl import DownloadError from youtube_dl import DownloadError
from youtube_dl.version import __version__ as YTDL_VERSION from youtube_dl.version import __version__ as YTDL_VERSION
@ -42,33 +48,41 @@ def mkdir_p(path):
raise raise
def datestr(date): def datestr(date: datetime):
return date.strftime("%Y-%m-%d@%H%M") return date.strftime("%Y-%m-%d@%H%M")
class DelojzaDB: class DelojzaDB:
def __init__(self, db_path): def __init__(self, db_path):
self.db_path = db_path self.db_path = db_path
self.db = None self.db: Optional[Connection] = None
def initialize(self): def initialize(self):
if self.db is None: if self.db is None:
self.db = sqlite3.connect(self.db_path) self.db = sqlite3.connect(self.db_path)
def get_protected_tags(self): def get_protected_tags(self):
if self.db is None:
raise RuntimeError("Database not initialized!")
results = self.db.execute("SELECT tag FROM tags WHERE protected == 1") results = self.db.execute("SELECT tag FROM tags WHERE protected == 1")
return [res[0] for res in results.fetchall()] return [res[0] for res in results.fetchall()]
def get_protected_chats(self): def get_protected_chats(self):
if self.db is None:
raise RuntimeError("Database not initialized!")
results = self.db.execute("SELECT id FROM chats WHERE protected == 1") results = self.db.execute("SELECT id FROM chats WHERE protected == 1")
return [res[0] for res in results.fetchall()] return [res[0] for res in results.fetchall()]
def get_chat(self, id): def get_chat(self, id):
if self.db is None:
raise RuntimeError("Database not initialized!")
return self.db.execute( return self.db.execute(
"SELECT id, protected FROM chats WHERE id == ?", (id,) "SELECT id, protected FROM chats WHERE id == ?", (id,)
).fetchone() ).fetchone()
def set_chat_protected(self, id, protected): def set_chat_protected(self, id: str, protected: bool):
if self.db is None:
raise RuntimeError("Database not initialized!")
chat_in_db = self.get_chat(id) chat_in_db = self.get_chat(id)
if chat_in_db: if chat_in_db:
self.db.execute( self.db.execute(
@ -80,12 +94,16 @@ class DelojzaDB:
) )
self.db.commit() self.db.commit()
def get_tag(self, tag): def get_tag(self, tag: str):
if self.db is None:
raise RuntimeError("Database not initialized!")
return self.db.execute( return self.db.execute(
"SELECT id, tag, protected FROM tags WHERE tag == ?", (tag,) "SELECT id, tag, protected FROM tags WHERE tag == ?", (tag,)
).fetchone() ).fetchone()
def set_tag_protected(self, tag, protected): def set_tag_protected(self, tag: str, protected: bool):
if self.db is None:
raise RuntimeError("Database not initialized!")
tag_in_db = self.get_tag(tag) tag_in_db = self.get_tag(tag)
if tag_in_db: if tag_in_db:
self.db.execute( self.db.execute(
@ -98,19 +116,40 @@ class DelojzaDB:
self.db.commit() self.db.commit()
class MarkovBlabberer:
def __init__(self, filepath: StrPath):
self.logger = logging.getLogger("markov")
self.filepath = filepath
with open(filepath) as f:
text = f.read()
self.markov: Text = markovify.NewlineText(text.lower())
self.logger.info("Sentence of the day: " + self.make_sentence())
def make_sentence(self, tries: int = 100):
return self.markov.make_sentence(tries=tries) or "???"
def add_to_corpus(self, text: str):
text = text.lower()
new_sentence = markovify.NewlineText(text)
self.markov = cast(Text, markovify.combine([self.markov, new_sentence]))
with open(self.filepath, "a") as f:
f.write(text + "\n")
class DelojzaBot: class DelojzaBot:
def __init__( def __init__(
self, self,
tg_api_key, tg_api_key: str,
out_dir, out_dir: StrPath,
redirects=None, redirects: Optional[List[Tuple[str, str]]] = None,
tmp_dir=None, tmp_dir: Optional[StrPath] = None,
db_path=None, db_path: Optional[StrPath] = None,
protected_password=None, protected_password: Optional[str] = None,
acoustid_key=None, acoustid_key: Optional[str] = None,
tumblr_name=None, tumblr_name: Optional[str] = None,
tumblr_keys=None, tumblr_keys: Optional[Tuple[str, str, str, str]] = None,
markov=None, markov: Optional[MarkovBlabberer] = None,
): ):
self._setup_logging(os.path.dirname(os.path.realpath(__file__))) self._setup_logging(os.path.dirname(os.path.realpath(__file__)))
@ -121,15 +160,16 @@ class DelojzaBot:
self.out_dir = os.path.abspath(out_dir) self.out_dir = os.path.abspath(out_dir)
self.out_dir = self.out_dir[:-1] if self.out_dir[-1] == "/" else self.out_dir self.out_dir = self.out_dir[:-1] if self.out_dir[-1] == "/" else self.out_dir
self.logger.debug("OUT_DIR: " + out_dir) self.logger.debug(f"OUT_DIR: {out_dir}")
self.tmp_dir = tmp_dir if tmp_dir else tempfile.gettempdir() self.tmp_dir = tmp_dir if tmp_dir else tempfile.gettempdir()
self.logger.debug("TMP_DIR: " + tmp_dir) self.logger.debug(f"TMP_DIR: {tmp_dir}")
self.markov = markov self.markov = markov
self.redirects = {} self.redirects = {}
if redirects is not None: if redirects is not None:
for hashtag, directory in redirects: for hashtag, directory in redirects:
hashtag = hashtag.upper() hashtag = hashtag.upper()
directory = str(directory)
directory = directory[:-1] if directory[-1] == "/" else directory directory = directory[:-1] if directory[-1] == "/" else directory
mkdir_p(directory) mkdir_p(directory)
self.redirects[hashtag] = directory self.redirects[hashtag] = directory
@ -162,14 +202,14 @@ class DelojzaBot:
self.last_downloaded = {} self.last_downloaded = {}
self.last_hashtags = {} self.last_hashtags = {}
def _setup_logging(self, log_path): def _setup_logging(self, log_path: StrPath):
self.logger = logging.getLogger("delojza") self.logger = logging.getLogger("delojza")
self.logger.setLevel(logging.DEBUG) self.logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(logging.INFO) ch.setLevel(logging.INFO)
dfh = logging.FileHandler(log_path + "/delojza.log") dfh = logging.FileHandler(os.path.join(log_path, "delojza.log"))
dfh.setLevel(logging.DEBUG) dfh.setLevel(logging.DEBUG)
formatter = logging.Formatter( formatter = logging.Formatter(
@ -190,7 +230,7 @@ class DelojzaBot:
) )
@staticmethod @staticmethod
def ytdl_can(url): def ytdl_can(url: str):
ies = youtube_dl.extractor.gen_extractors() ies = youtube_dl.extractor.gen_extractors()
for ie in ies: for ie in ies:
if ie.suitable(url) and ie.IE_NAME != "generic" and "/channel/" not in url: if ie.suitable(url) and ie.IE_NAME != "generic" and "/channel/" not in url:
@ -200,18 +240,18 @@ class DelojzaBot:
# https://github.com/django/django/blob/master/django/utils/text.py#L393 # https://github.com/django/django/blob/master/django/utils/text.py#L393
@staticmethod @staticmethod
def sanitize(filepath): def sanitize(text: str):
if filepath is None: if text is None:
return None return ""
filepath = ( text = (
unicodedata.normalize("NFKD", filepath) unicodedata.normalize("NFKD", text)
.encode("ascii", "ignore") .encode("ascii", "ignore")
.decode("ascii") .decode("ascii")
) )
return re.sub(r"[^\w.()\[\]{}#-]", "_", filepath) return re.sub(r"[^\w.()\[\]{}#-]", "_", text)
@staticmethod @staticmethod
def _get_tags(filepath): def _get_tags(filepath: StrPath):
try: try:
audio = EasyID3(filepath) audio = EasyID3(filepath)
return ( return (
@ -222,11 +262,13 @@ class DelojzaBot:
return None, None return None, None
@staticmethod @staticmethod
def _tag_file(filepath, artist, title): def _tag_file(filepath: StrPath, artist: Optional[str], title: str):
try: try:
id3 = mutagen.id3.ID3(filepath) id3 = mutagen.id3.ID3(filepath)
except mutagen.id3.ID3NoHeaderError: except mutagen.id3.ID3NoHeaderError:
mutafile = mutagen.File(filepath) mutafile = cast(Optional[FileType], File(filepath))
if not mutafile:
return
mutafile.add_tags() mutafile.add_tags()
mutafile.save() mutafile.save()
id3 = mutagen.id3.ID3(filepath) id3 = mutagen.id3.ID3(filepath)
@ -258,7 +300,7 @@ class DelojzaBot:
score, rid, aid_title, aid_artist = results[0] score, rid, aid_title, aid_artist = results[0]
if score > 0.4: if score > 0.4:
title = aid_title title = aid_title
artist = re.sub(r" *; +", " & ", aid_artist) artist = re.sub(r" *; +", " & ", aid_artist or "")
best_acoustid_score = score best_acoustid_score = score
source = "AcoustID ({}%)".format(round(score * 100)) source = "AcoustID ({}%)".format(round(score * 100))
except acoustid.NoBackendError: except acoustid.NoBackendError:
@ -296,7 +338,7 @@ class DelojzaBot:
artist = artist.strip() if artist else None artist = artist.strip() if artist else None
title = title.strip() if title else None title = title.strip() if title else None
if title is None and artist is None: if title is None:
message.reply_text("Tried tagging, found nothing :(") message.reply_text("Tried tagging, found nothing :(")
return return
@ -309,7 +351,7 @@ class DelojzaBot:
self._tag_file(filepath, artist, title) self._tag_file(filepath, artist, title)
@staticmethod @staticmethod
def _get_percent_filled(directory): def _get_percent_filled(directory: str):
output = subprocess.check_output(["df", directory]) output = subprocess.check_output(["df", directory])
percents_re = re.search(r"[0-9]+%", output.decode("utf-8")) percents_re = re.search(r"[0-9]+%", output.decode("utf-8"))
if not percents_re: if not percents_re:
@ -317,7 +359,15 @@ class DelojzaBot:
return int(percents_re.group(0)[:-1]) return int(percents_re.group(0)[:-1])
# noinspection PyUnusedLocal # noinspection PyUnusedLocal
def download_ytdl(self, urls, out_path, date, message, audio=False, filetitle=None): def download_ytdl(
self,
urls: List[str],
out_path: StrPath,
date: datetime,
message: telegram.Message,
audio: bool = False,
filetitle: Optional[str] = None,
):
ytdl = { ytdl = {
"noplaylist": True, "noplaylist": True,
"restrictfilenames": True, "restrictfilenames": True,
@ -352,7 +402,7 @@ class DelojzaBot:
else: else:
raise exc raise exc
for info in [ytdl.extract_info(url, download=False) for url in urls]: for info in [ytdl.extract_info(url, download=False) for url in urls]:
filename = ytdl.prepare_filename(info) filename = cast(str, ytdl.prepare_filename(info))
globbeds = glob(os.path.splitext(filename)[0] + ".*") globbeds = glob(os.path.splitext(filename)[0] + ".*")
for globbed in globbeds: for globbed in globbeds:
if globbed.endswith("mp3"): if globbed.endswith("mp3"):
@ -362,7 +412,15 @@ class DelojzaBot:
filenames.append(dest) filenames.append(dest)
return filenames return filenames
def download_raw(self, urls, out_path, date, message, audio=False, filetitle=None): def download_raw(
self,
urls: List[str],
out_path: StrPath,
date: datetime,
message: telegram.Message,
audio: bool = False,
filetitle: Optional[str] = None,
):
filenames = [] filenames = []
for url in urls: for url in urls:
local_filename = os.path.join( local_filename = os.path.join(
@ -409,7 +467,7 @@ class DelojzaBot:
return filenames return filenames
@staticmethod @staticmethod
def extract_hashtags(message): def extract_hashtags(message: telegram.Message):
hashtags = list( hashtags = list(
map( map(
message.parse_entity, message.parse_entity,
@ -429,7 +487,7 @@ class DelojzaBot:
hashtags[i] = "PRAS" hashtags[i] = "PRAS"
return hashtags return hashtags
def _get_hashtags(self, message): def _get_hashtags(self, message: telegram.Message):
hashtags = self.extract_hashtags(message) hashtags = self.extract_hashtags(message)
if len(hashtags) == 0 and self.last_hashtags.get(message.chat.id) is not None: if len(hashtags) == 0 and self.last_hashtags.get(message.chat.id) is not None:
user, ts, last_hashtags = self.last_hashtags[message.chat.id] user, ts, last_hashtags = self.last_hashtags[message.chat.id]
@ -437,7 +495,7 @@ class DelojzaBot:
hashtags = last_hashtags hashtags = last_hashtags
return hashtags return hashtags
def handle_text(self, message, hashtags): def handle_text(self, message: telegram.Message, hashtags: List[str]):
if len(hashtags) == 0 or hashtags[0] not in ("TEXT", "TXT"): if len(hashtags) == 0 or hashtags[0] not in ("TEXT", "TXT"):
return return
@ -464,7 +522,14 @@ class DelojzaBot:
) )
# noinspection PyBroadException # noinspection PyBroadException
def handle(self, urls, message, hashtags, download_fn, filetitle=None): def handle(
self,
urls: List[str],
message: telegram.Message,
hashtags: List[str],
download_fn: Any,
filetitle=None,
):
self.db.initialize() self.db.initialize()
try: try:
@ -605,7 +670,7 @@ class DelojzaBot:
else: else:
return False return False
def handle_urls(self, message, hashtags): def handle_urls(self, message: telegram.Message, hashtags: List[str]):
urls = list( urls = list(
map( map(
lambda e: message.parse_entity(e), lambda e: message.parse_entity(e),
@ -631,7 +696,7 @@ class DelojzaBot:
return ytdl_res or raw_res return ytdl_res or raw_res
def tg_handle(self, bot, update): def tg_handle(self, bot: telegram.Bot, update: telegram.Update):
self._log_msg(update) self._log_msg(update)
hashtags = self._get_hashtags(update.message) hashtags = self._get_hashtags(update.message)
if hashtags: if hashtags:
@ -682,7 +747,7 @@ class DelojzaBot:
+ list(self.redirects.keys()) + list(self.redirects.keys())
) )
def tg_stats(self, _, update): def tg_stats(self, _, update: telegram.Update):
self._log_msg(update) self._log_msg(update)
self.db.initialize() self.db.initialize()
if update.message.chat.id not in self.db.get_protected_chats(): if update.message.chat.id not in self.db.get_protected_chats():
@ -750,7 +815,7 @@ class DelojzaBot:
result.append((directory, "NO FILE AT ALL...")) result.append((directory, "NO FILE AT ALL..."))
return sorted(result, key=itemgetter(0)) return sorted(result, key=itemgetter(0))
def tg_orphan(self, _, update): def tg_orphan(self, _, update: telegram.Update):
self._log_msg(update) self._log_msg(update)
self.db.initialize() self.db.initialize()
if update.message.chat.id not in self.db.get_protected_chats(): if update.message.chat.id not in self.db.get_protected_chats():
@ -793,7 +858,7 @@ class DelojzaBot:
if len(tmp_reply) > 0: if len(tmp_reply) > 0:
update.message.reply_text(tmp_reply) update.message.reply_text(tmp_reply)
def tg_retag(self, _, update): def tg_retag(self, _, update: telegram.Update):
self._log_msg(update) self._log_msg(update)
if self.last_downloaded.get(update.message.chat.id) is not None: if self.last_downloaded.get(update.message.chat.id) is not None:
files, hashtags, tumblr_ids = self.last_downloaded[update.message.chat.id] files, hashtags, tumblr_ids = self.last_downloaded[update.message.chat.id]
@ -817,7 +882,7 @@ class DelojzaBot:
orig_artist, orig_title = self._get_tags(mp3) orig_artist, orig_title = self._get_tags(mp3)
title, artist = orig_artist, orig_title title, artist = orig_artist, orig_title
self._tag_file(mp3, artist, title) self._tag_file(mp3, artist, cast(str, title))
update.message.reply_text( update.message.reply_text(
'Tagging "{}" as "{}" by "{}"!'.format( 'Tagging "{}" as "{}" by "{}"!'.format(
mp3[len(out_dir) + 1 :], title, artist mp3[len(out_dir) + 1 :], title, artist
@ -829,7 +894,7 @@ class DelojzaBot:
+ "???" + "???"
) )
def tg_delete(self, _, update): def tg_delete(self, _, update: telegram.Update):
self._log_msg(update) self._log_msg(update)
if self.last_downloaded.get(update.message.chat.id) is not None: if self.last_downloaded.get(update.message.chat.id) is not None:
files, hashtags, tumblr_ids = self.last_downloaded[update.message.chat.id] files, hashtags, tumblr_ids = self.last_downloaded[update.message.chat.id]
@ -865,7 +930,7 @@ class DelojzaBot:
return return
update.message.reply_text("Nothing to remove!") update.message.reply_text("Nothing to remove!")
def tg_protect(self, _, update): def tg_protect(self, _, update: telegram.Update):
self._log_msg(update) self._log_msg(update)
self.db.initialize() self.db.initialize()
@ -929,7 +994,7 @@ class DelojzaBot:
+ "???" + "???"
) )
def tg_queue(self, _, update): def tg_queue(self, _, update: telegram.Update):
if self.tumblr_client: if self.tumblr_client:
blog_info = self.tumblr_client.blog_info(self.tumblr_name) blog_info = self.tumblr_client.blog_info(self.tumblr_name)
update.message.reply_text( update.message.reply_text(
@ -943,7 +1008,7 @@ class DelojzaBot:
) )
# noinspection PyMethodMayBeStatic # noinspection PyMethodMayBeStatic
def tg_version(self, _, update): def tg_version(self, _, update: telegram.Update):
self._log_msg(update) self._log_msg(update)
delojza_date = datetime.fromtimestamp( delojza_date = datetime.fromtimestamp(
os.path.getmtime(os.path.realpath(__file__)) os.path.getmtime(os.path.realpath(__file__))
@ -954,13 +1019,15 @@ class DelojzaBot:
) )
) )
def tg_start(self, _, update): def tg_start(self, _, update: telegram.Update):
self._log_msg(update) self._log_msg(update)
update.message.reply_text( update.message.reply_text(
self.markov.make_sentence() if self.markov else "HELLO" self.markov.make_sentence() if self.markov else "HELLO"
) )
def tg_error(self, bot, update, error): def tg_error(
self, bot: telegram.Bot, update: telegram.Update, error: telegram.TelegramError
):
self.logger.error(error) self.logger.error(error)
if "Timed out" in str(error): if "Timed out" in str(error):
if update is not None: if update is not None:
@ -980,27 +1047,6 @@ class DelojzaBot:
self.updater.idle() self.updater.idle()
class MarkovBlabberer:
def __init__(self, filepath):
self.logger = logging.getLogger("markov")
self.filepath = filepath
with open(filepath) as f:
text = f.read()
self.markov = markovify.NewlineText(text.lower())
self.logger.info("Sentence of the day: " + self.make_sentence())
def make_sentence(self, tries=100):
return self.markov.make_sentence(tries=tries) or "???"
def add_to_corpus(self, text):
text = text.lower()
new_sentence = markovify.NewlineText(text)
self.markov = markovify.combine([self.markov, new_sentence])
with open(self.filepath, "a") as f:
f.write(text + "\n")
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@ -1033,9 +1079,19 @@ if __name__ == "__main__":
markov = None markov = None
try: try:
redirects = config.items("redirects") redirects: Optional[List[Tuple[str, str]]] = config.items("redirects")
except NoSectionError: except NoSectionError:
redirects = {} redirects = None
try:
tumblr_keys = (
config.get("tumblr", "consumer_key"),
config.get("tumblr", "consumer_secret"),
config.get("tumblr", "oauth_key"),
config.get("tumblr", "oauth_secret"),
)
except (NoSectionError, KeyError):
tumblr_keys = None
delojza = DelojzaBot( delojza = DelojzaBot(
config.get("delojza", "tg_api_key"), config.get("delojza", "tg_api_key"),
@ -1045,12 +1101,7 @@ if __name__ == "__main__":
protected_password=config.get("delojza", "protected_password", fallback=None), protected_password=config.get("delojza", "protected_password", fallback=None),
acoustid_key=config.get("delojza", "acoustid_api_key", fallback=None), acoustid_key=config.get("delojza", "acoustid_api_key", fallback=None),
tumblr_name=config.get("tumblr", "blog_name", fallback=None), tumblr_name=config.get("tumblr", "blog_name", fallback=None),
tumblr_keys=( tumblr_keys=tumblr_keys,
config.get("tumblr", "consumer_key", fallback=None),
config.get("tumblr", "consumer_secret", fallback=None),
config.get("tumblr", "oauth_key", fallback=None),
config.get("tumblr", "oauth_secret", fallback=None),
),
markov=markov, markov=markov,
) )
delojza.run_idle() delojza.run_idle()