diff --git a/delojza.py b/delojza.py index 934f457..06a65cc 100755 --- a/delojza.py +++ b/delojza.py @@ -16,7 +16,9 @@ from datetime import datetime, timedelta from glob import glob from operator import itemgetter from random import random +from sqlite3.dbapi2 import Connection from time import sleep +from typing import Any, List, Optional, Tuple, cast import acoustid import filetype @@ -26,8 +28,12 @@ import pytumblr import requests import telegram import youtube_dl +from _typeshed import StrPath +from markovify.text import Text +from mutagen import File, FileType from mutagen.easyid3 import EasyID3 -from telegram.ext import Updater, CommandHandler, MessageHandler +from requests.api import options +from telegram.ext import CommandHandler, MessageHandler, Updater from youtube_dl import DownloadError from youtube_dl.version import __version__ as YTDL_VERSION @@ -42,33 +48,41 @@ def mkdir_p(path): raise -def datestr(date): +def datestr(date: datetime): return date.strftime("%Y-%m-%d@%H%M") class DelojzaDB: def __init__(self, db_path): self.db_path = db_path - self.db = None + self.db: Optional[Connection] = None def initialize(self): if self.db is None: self.db = sqlite3.connect(self.db_path) def get_protected_tags(self): + if self.db is None: + raise RuntimeError("Database not initialized!") results = self.db.execute("SELECT tag FROM tags WHERE protected == 1") return [res[0] for res in results.fetchall()] def get_protected_chats(self): + if self.db is None: + raise RuntimeError("Database not initialized!") results = self.db.execute("SELECT id FROM chats WHERE protected == 1") return [res[0] for res in results.fetchall()] def get_chat(self, id): + if self.db is None: + raise RuntimeError("Database not initialized!") return self.db.execute( "SELECT id, protected FROM chats WHERE id == ?", (id,) ).fetchone() - def set_chat_protected(self, id, protected): + def set_chat_protected(self, id: str, protected: bool): + if self.db is None: + raise RuntimeError("Database not initialized!") chat_in_db = self.get_chat(id) if chat_in_db: self.db.execute( @@ -80,12 +94,16 @@ class DelojzaDB: ) self.db.commit() - def get_tag(self, tag): + def get_tag(self, tag: str): + if self.db is None: + raise RuntimeError("Database not initialized!") return self.db.execute( "SELECT id, tag, protected FROM tags WHERE tag == ?", (tag,) ).fetchone() - def set_tag_protected(self, tag, protected): + def set_tag_protected(self, tag: str, protected: bool): + if self.db is None: + raise RuntimeError("Database not initialized!") tag_in_db = self.get_tag(tag) if tag_in_db: self.db.execute( @@ -98,19 +116,40 @@ class DelojzaDB: self.db.commit() +class MarkovBlabberer: + def __init__(self, filepath: StrPath): + self.logger = logging.getLogger("markov") + self.filepath = filepath + + with open(filepath) as f: + text = f.read() + self.markov: Text = markovify.NewlineText(text.lower()) + self.logger.info("Sentence of the day: " + self.make_sentence()) + + def make_sentence(self, tries: int = 100): + return self.markov.make_sentence(tries=tries) or "???" + + def add_to_corpus(self, text: str): + text = text.lower() + new_sentence = markovify.NewlineText(text) + self.markov = cast(Text, markovify.combine([self.markov, new_sentence])) + with open(self.filepath, "a") as f: + f.write(text + "\n") + + class DelojzaBot: def __init__( self, - tg_api_key, - out_dir, - redirects=None, - tmp_dir=None, - db_path=None, - protected_password=None, - acoustid_key=None, - tumblr_name=None, - tumblr_keys=None, - markov=None, + tg_api_key: str, + out_dir: StrPath, + redirects: Optional[List[Tuple[str, str]]] = None, + tmp_dir: Optional[StrPath] = None, + db_path: Optional[StrPath] = None, + protected_password: Optional[str] = None, + acoustid_key: Optional[str] = None, + tumblr_name: Optional[str] = None, + tumblr_keys: Optional[Tuple[str, str, str, str]] = None, + markov: Optional[MarkovBlabberer] = None, ): self._setup_logging(os.path.dirname(os.path.realpath(__file__))) @@ -121,15 +160,16 @@ class DelojzaBot: self.out_dir = os.path.abspath(out_dir) self.out_dir = self.out_dir[:-1] if self.out_dir[-1] == "/" else self.out_dir - self.logger.debug("OUT_DIR: " + out_dir) + self.logger.debug(f"OUT_DIR: {out_dir}") self.tmp_dir = tmp_dir if tmp_dir else tempfile.gettempdir() - self.logger.debug("TMP_DIR: " + tmp_dir) + self.logger.debug(f"TMP_DIR: {tmp_dir}") self.markov = markov self.redirects = {} if redirects is not None: for hashtag, directory in redirects: hashtag = hashtag.upper() + directory = str(directory) directory = directory[:-1] if directory[-1] == "/" else directory mkdir_p(directory) self.redirects[hashtag] = directory @@ -162,14 +202,14 @@ class DelojzaBot: self.last_downloaded = {} self.last_hashtags = {} - def _setup_logging(self, log_path): + def _setup_logging(self, log_path: StrPath): self.logger = logging.getLogger("delojza") self.logger.setLevel(logging.DEBUG) ch = logging.StreamHandler() ch.setLevel(logging.INFO) - dfh = logging.FileHandler(log_path + "/delojza.log") + dfh = logging.FileHandler(os.path.join(log_path, "delojza.log")) dfh.setLevel(logging.DEBUG) formatter = logging.Formatter( @@ -190,7 +230,7 @@ class DelojzaBot: ) @staticmethod - def ytdl_can(url): + def ytdl_can(url: str): ies = youtube_dl.extractor.gen_extractors() for ie in ies: if ie.suitable(url) and ie.IE_NAME != "generic" and "/channel/" not in url: @@ -200,18 +240,18 @@ class DelojzaBot: # https://github.com/django/django/blob/master/django/utils/text.py#L393 @staticmethod - def sanitize(filepath): - if filepath is None: - return None - filepath = ( - unicodedata.normalize("NFKD", filepath) + def sanitize(text: str): + if text is None: + return "" + text = ( + unicodedata.normalize("NFKD", text) .encode("ascii", "ignore") .decode("ascii") ) - return re.sub(r"[^\w.()\[\]{}#-]", "_", filepath) + return re.sub(r"[^\w.()\[\]{}#-]", "_", text) @staticmethod - def _get_tags(filepath): + def _get_tags(filepath: StrPath): try: audio = EasyID3(filepath) return ( @@ -222,11 +262,13 @@ class DelojzaBot: return None, None @staticmethod - def _tag_file(filepath, artist, title): + def _tag_file(filepath: StrPath, artist: Optional[str], title: str): try: id3 = mutagen.id3.ID3(filepath) except mutagen.id3.ID3NoHeaderError: - mutafile = mutagen.File(filepath) + mutafile = cast(Optional[FileType], File(filepath)) + if not mutafile: + return mutafile.add_tags() mutafile.save() id3 = mutagen.id3.ID3(filepath) @@ -258,7 +300,7 @@ class DelojzaBot: score, rid, aid_title, aid_artist = results[0] if score > 0.4: title = aid_title - artist = re.sub(r" *; +", " & ", aid_artist) + artist = re.sub(r" *; +", " & ", aid_artist or "") best_acoustid_score = score source = "AcoustID ({}%)".format(round(score * 100)) except acoustid.NoBackendError: @@ -296,7 +338,7 @@ class DelojzaBot: artist = artist.strip() if artist else None title = title.strip() if title else None - if title is None and artist is None: + if title is None: message.reply_text("Tried tagging, found nothing :(") return @@ -309,7 +351,7 @@ class DelojzaBot: self._tag_file(filepath, artist, title) @staticmethod - def _get_percent_filled(directory): + def _get_percent_filled(directory: str): output = subprocess.check_output(["df", directory]) percents_re = re.search(r"[0-9]+%", output.decode("utf-8")) if not percents_re: @@ -317,7 +359,15 @@ class DelojzaBot: return int(percents_re.group(0)[:-1]) # noinspection PyUnusedLocal - def download_ytdl(self, urls, out_path, date, message, audio=False, filetitle=None): + def download_ytdl( + self, + urls: List[str], + out_path: StrPath, + date: datetime, + message: telegram.Message, + audio: bool = False, + filetitle: Optional[str] = None, + ): ytdl = { "noplaylist": True, "restrictfilenames": True, @@ -352,7 +402,7 @@ class DelojzaBot: else: raise exc for info in [ytdl.extract_info(url, download=False) for url in urls]: - filename = ytdl.prepare_filename(info) + filename = cast(str, ytdl.prepare_filename(info)) globbeds = glob(os.path.splitext(filename)[0] + ".*") for globbed in globbeds: if globbed.endswith("mp3"): @@ -362,7 +412,15 @@ class DelojzaBot: filenames.append(dest) return filenames - def download_raw(self, urls, out_path, date, message, audio=False, filetitle=None): + def download_raw( + self, + urls: List[str], + out_path: StrPath, + date: datetime, + message: telegram.Message, + audio: bool = False, + filetitle: Optional[str] = None, + ): filenames = [] for url in urls: local_filename = os.path.join( @@ -409,7 +467,7 @@ class DelojzaBot: return filenames @staticmethod - def extract_hashtags(message): + def extract_hashtags(message: telegram.Message): hashtags = list( map( message.parse_entity, @@ -429,7 +487,7 @@ class DelojzaBot: hashtags[i] = "PRAS" return hashtags - def _get_hashtags(self, message): + def _get_hashtags(self, message: telegram.Message): hashtags = self.extract_hashtags(message) if len(hashtags) == 0 and self.last_hashtags.get(message.chat.id) is not None: user, ts, last_hashtags = self.last_hashtags[message.chat.id] @@ -437,7 +495,7 @@ class DelojzaBot: hashtags = last_hashtags return hashtags - def handle_text(self, message, hashtags): + def handle_text(self, message: telegram.Message, hashtags: List[str]): if len(hashtags) == 0 or hashtags[0] not in ("TEXT", "TXT"): return @@ -464,7 +522,14 @@ class DelojzaBot: ) # noinspection PyBroadException - def handle(self, urls, message, hashtags, download_fn, filetitle=None): + def handle( + self, + urls: List[str], + message: telegram.Message, + hashtags: List[str], + download_fn: Any, + filetitle=None, + ): self.db.initialize() try: @@ -605,7 +670,7 @@ class DelojzaBot: else: return False - def handle_urls(self, message, hashtags): + def handle_urls(self, message: telegram.Message, hashtags: List[str]): urls = list( map( lambda e: message.parse_entity(e), @@ -631,7 +696,7 @@ class DelojzaBot: return ytdl_res or raw_res - def tg_handle(self, bot, update): + def tg_handle(self, bot: telegram.Bot, update: telegram.Update): self._log_msg(update) hashtags = self._get_hashtags(update.message) if hashtags: @@ -682,7 +747,7 @@ class DelojzaBot: + list(self.redirects.keys()) ) - def tg_stats(self, _, update): + def tg_stats(self, _, update: telegram.Update): self._log_msg(update) self.db.initialize() if update.message.chat.id not in self.db.get_protected_chats(): @@ -750,7 +815,7 @@ class DelojzaBot: result.append((directory, "NO FILE AT ALL...")) return sorted(result, key=itemgetter(0)) - def tg_orphan(self, _, update): + def tg_orphan(self, _, update: telegram.Update): self._log_msg(update) self.db.initialize() if update.message.chat.id not in self.db.get_protected_chats(): @@ -793,7 +858,7 @@ class DelojzaBot: if len(tmp_reply) > 0: update.message.reply_text(tmp_reply) - def tg_retag(self, _, update): + def tg_retag(self, _, update: telegram.Update): self._log_msg(update) if self.last_downloaded.get(update.message.chat.id) is not None: files, hashtags, tumblr_ids = self.last_downloaded[update.message.chat.id] @@ -817,7 +882,7 @@ class DelojzaBot: orig_artist, orig_title = self._get_tags(mp3) title, artist = orig_artist, orig_title - self._tag_file(mp3, artist, title) + self._tag_file(mp3, artist, cast(str, title)) update.message.reply_text( 'Tagging "{}" as "{}" by "{}"!'.format( mp3[len(out_dir) + 1 :], title, artist @@ -829,7 +894,7 @@ class DelojzaBot: + "???" ) - def tg_delete(self, _, update): + def tg_delete(self, _, update: telegram.Update): self._log_msg(update) if self.last_downloaded.get(update.message.chat.id) is not None: files, hashtags, tumblr_ids = self.last_downloaded[update.message.chat.id] @@ -865,7 +930,7 @@ class DelojzaBot: return update.message.reply_text("Nothing to remove!") - def tg_protect(self, _, update): + def tg_protect(self, _, update: telegram.Update): self._log_msg(update) self.db.initialize() @@ -929,7 +994,7 @@ class DelojzaBot: + "???" ) - def tg_queue(self, _, update): + def tg_queue(self, _, update: telegram.Update): if self.tumblr_client: blog_info = self.tumblr_client.blog_info(self.tumblr_name) update.message.reply_text( @@ -943,7 +1008,7 @@ class DelojzaBot: ) # noinspection PyMethodMayBeStatic - def tg_version(self, _, update): + def tg_version(self, _, update: telegram.Update): self._log_msg(update) delojza_date = datetime.fromtimestamp( os.path.getmtime(os.path.realpath(__file__)) @@ -954,13 +1019,15 @@ class DelojzaBot: ) ) - def tg_start(self, _, update): + def tg_start(self, _, update: telegram.Update): self._log_msg(update) update.message.reply_text( self.markov.make_sentence() if self.markov else "HELLO" ) - def tg_error(self, bot, update, error): + def tg_error( + self, bot: telegram.Bot, update: telegram.Update, error: telegram.TelegramError + ): self.logger.error(error) if "Timed out" in str(error): if update is not None: @@ -980,27 +1047,6 @@ class DelojzaBot: self.updater.idle() -class MarkovBlabberer: - def __init__(self, filepath): - self.logger = logging.getLogger("markov") - self.filepath = filepath - - with open(filepath) as f: - text = f.read() - self.markov = markovify.NewlineText(text.lower()) - self.logger.info("Sentence of the day: " + self.make_sentence()) - - def make_sentence(self, tries=100): - return self.markov.make_sentence(tries=tries) or "???" - - def add_to_corpus(self, text): - text = text.lower() - new_sentence = markovify.NewlineText(text) - self.markov = markovify.combine([self.markov, new_sentence]) - with open(self.filepath, "a") as f: - f.write(text + "\n") - - if __name__ == "__main__": logging.basicConfig( level=logging.INFO, @@ -1033,9 +1079,19 @@ if __name__ == "__main__": markov = None try: - redirects = config.items("redirects") + redirects: Optional[List[Tuple[str, str]]] = config.items("redirects") except NoSectionError: - redirects = {} + redirects = None + + try: + tumblr_keys = ( + config.get("tumblr", "consumer_key"), + config.get("tumblr", "consumer_secret"), + config.get("tumblr", "oauth_key"), + config.get("tumblr", "oauth_secret"), + ) + except (NoSectionError, KeyError): + tumblr_keys = None delojza = DelojzaBot( config.get("delojza", "tg_api_key"), @@ -1045,12 +1101,7 @@ if __name__ == "__main__": protected_password=config.get("delojza", "protected_password", fallback=None), acoustid_key=config.get("delojza", "acoustid_api_key", fallback=None), tumblr_name=config.get("tumblr", "blog_name", fallback=None), - tumblr_keys=( - config.get("tumblr", "consumer_key", fallback=None), - config.get("tumblr", "consumer_secret", fallback=None), - config.get("tumblr", "oauth_key", fallback=None), - config.get("tumblr", "oauth_secret", fallback=None), - ), + tumblr_keys=tumblr_keys, markov=markov, ) delojza.run_idle()