add basic typing

master
Tomáš Mládek 2021-09-14 16:03:55 +02:00
parent 604a1378b7
commit 6b7949e1a1
1 changed files with 132 additions and 81 deletions

View File

@ -16,7 +16,9 @@ from datetime import datetime, timedelta
from glob import glob
from operator import itemgetter
from random import random
from sqlite3.dbapi2 import Connection
from time import sleep
from typing import Any, List, Optional, Tuple, cast
import acoustid
import filetype
@ -26,8 +28,12 @@ import pytumblr
import requests
import telegram
import youtube_dl
from _typeshed import StrPath
from markovify.text import Text
from mutagen import File, FileType
from mutagen.easyid3 import EasyID3
from telegram.ext import Updater, CommandHandler, MessageHandler
from requests.api import options
from telegram.ext import CommandHandler, MessageHandler, Updater
from youtube_dl import DownloadError
from youtube_dl.version import __version__ as YTDL_VERSION
@ -42,33 +48,41 @@ def mkdir_p(path):
raise
def datestr(date):
def datestr(date: datetime):
return date.strftime("%Y-%m-%d@%H%M")
class DelojzaDB:
def __init__(self, db_path):
self.db_path = db_path
self.db = None
self.db: Optional[Connection] = None
def initialize(self):
if self.db is None:
self.db = sqlite3.connect(self.db_path)
def get_protected_tags(self):
if self.db is None:
raise RuntimeError("Database not initialized!")
results = self.db.execute("SELECT tag FROM tags WHERE protected == 1")
return [res[0] for res in results.fetchall()]
def get_protected_chats(self):
if self.db is None:
raise RuntimeError("Database not initialized!")
results = self.db.execute("SELECT id FROM chats WHERE protected == 1")
return [res[0] for res in results.fetchall()]
def get_chat(self, id):
if self.db is None:
raise RuntimeError("Database not initialized!")
return self.db.execute(
"SELECT id, protected FROM chats WHERE id == ?", (id,)
).fetchone()
def set_chat_protected(self, id, protected):
def set_chat_protected(self, id: str, protected: bool):
if self.db is None:
raise RuntimeError("Database not initialized!")
chat_in_db = self.get_chat(id)
if chat_in_db:
self.db.execute(
@ -80,12 +94,16 @@ class DelojzaDB:
)
self.db.commit()
def get_tag(self, tag):
def get_tag(self, tag: str):
if self.db is None:
raise RuntimeError("Database not initialized!")
return self.db.execute(
"SELECT id, tag, protected FROM tags WHERE tag == ?", (tag,)
).fetchone()
def set_tag_protected(self, tag, protected):
def set_tag_protected(self, tag: str, protected: bool):
if self.db is None:
raise RuntimeError("Database not initialized!")
tag_in_db = self.get_tag(tag)
if tag_in_db:
self.db.execute(
@ -98,19 +116,40 @@ class DelojzaDB:
self.db.commit()
class MarkovBlabberer:
def __init__(self, filepath: StrPath):
self.logger = logging.getLogger("markov")
self.filepath = filepath
with open(filepath) as f:
text = f.read()
self.markov: Text = markovify.NewlineText(text.lower())
self.logger.info("Sentence of the day: " + self.make_sentence())
def make_sentence(self, tries: int = 100):
return self.markov.make_sentence(tries=tries) or "???"
def add_to_corpus(self, text: str):
text = text.lower()
new_sentence = markovify.NewlineText(text)
self.markov = cast(Text, markovify.combine([self.markov, new_sentence]))
with open(self.filepath, "a") as f:
f.write(text + "\n")
class DelojzaBot:
def __init__(
self,
tg_api_key,
out_dir,
redirects=None,
tmp_dir=None,
db_path=None,
protected_password=None,
acoustid_key=None,
tumblr_name=None,
tumblr_keys=None,
markov=None,
tg_api_key: str,
out_dir: StrPath,
redirects: Optional[List[Tuple[str, str]]] = None,
tmp_dir: Optional[StrPath] = None,
db_path: Optional[StrPath] = None,
protected_password: Optional[str] = None,
acoustid_key: Optional[str] = None,
tumblr_name: Optional[str] = None,
tumblr_keys: Optional[Tuple[str, str, str, str]] = None,
markov: Optional[MarkovBlabberer] = None,
):
self._setup_logging(os.path.dirname(os.path.realpath(__file__)))
@ -121,15 +160,16 @@ class DelojzaBot:
self.out_dir = os.path.abspath(out_dir)
self.out_dir = self.out_dir[:-1] if self.out_dir[-1] == "/" else self.out_dir
self.logger.debug("OUT_DIR: " + out_dir)
self.logger.debug(f"OUT_DIR: {out_dir}")
self.tmp_dir = tmp_dir if tmp_dir else tempfile.gettempdir()
self.logger.debug("TMP_DIR: " + tmp_dir)
self.logger.debug(f"TMP_DIR: {tmp_dir}")
self.markov = markov
self.redirects = {}
if redirects is not None:
for hashtag, directory in redirects:
hashtag = hashtag.upper()
directory = str(directory)
directory = directory[:-1] if directory[-1] == "/" else directory
mkdir_p(directory)
self.redirects[hashtag] = directory
@ -162,14 +202,14 @@ class DelojzaBot:
self.last_downloaded = {}
self.last_hashtags = {}
def _setup_logging(self, log_path):
def _setup_logging(self, log_path: StrPath):
self.logger = logging.getLogger("delojza")
self.logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
dfh = logging.FileHandler(log_path + "/delojza.log")
dfh = logging.FileHandler(os.path.join(log_path, "delojza.log"))
dfh.setLevel(logging.DEBUG)
formatter = logging.Formatter(
@ -190,7 +230,7 @@ class DelojzaBot:
)
@staticmethod
def ytdl_can(url):
def ytdl_can(url: str):
ies = youtube_dl.extractor.gen_extractors()
for ie in ies:
if ie.suitable(url) and ie.IE_NAME != "generic" and "/channel/" not in url:
@ -200,18 +240,18 @@ class DelojzaBot:
# https://github.com/django/django/blob/master/django/utils/text.py#L393
@staticmethod
def sanitize(filepath):
if filepath is None:
return None
filepath = (
unicodedata.normalize("NFKD", filepath)
def sanitize(text: str):
if text is None:
return ""
text = (
unicodedata.normalize("NFKD", text)
.encode("ascii", "ignore")
.decode("ascii")
)
return re.sub(r"[^\w.()\[\]{}#-]", "_", filepath)
return re.sub(r"[^\w.()\[\]{}#-]", "_", text)
@staticmethod
def _get_tags(filepath):
def _get_tags(filepath: StrPath):
try:
audio = EasyID3(filepath)
return (
@ -222,11 +262,13 @@ class DelojzaBot:
return None, None
@staticmethod
def _tag_file(filepath, artist, title):
def _tag_file(filepath: StrPath, artist: Optional[str], title: str):
try:
id3 = mutagen.id3.ID3(filepath)
except mutagen.id3.ID3NoHeaderError:
mutafile = mutagen.File(filepath)
mutafile = cast(Optional[FileType], File(filepath))
if not mutafile:
return
mutafile.add_tags()
mutafile.save()
id3 = mutagen.id3.ID3(filepath)
@ -258,7 +300,7 @@ class DelojzaBot:
score, rid, aid_title, aid_artist = results[0]
if score > 0.4:
title = aid_title
artist = re.sub(r" *; +", " & ", aid_artist)
artist = re.sub(r" *; +", " & ", aid_artist or "")
best_acoustid_score = score
source = "AcoustID ({}%)".format(round(score * 100))
except acoustid.NoBackendError:
@ -296,7 +338,7 @@ class DelojzaBot:
artist = artist.strip() if artist else None
title = title.strip() if title else None
if title is None and artist is None:
if title is None:
message.reply_text("Tried tagging, found nothing :(")
return
@ -309,7 +351,7 @@ class DelojzaBot:
self._tag_file(filepath, artist, title)
@staticmethod
def _get_percent_filled(directory):
def _get_percent_filled(directory: str):
output = subprocess.check_output(["df", directory])
percents_re = re.search(r"[0-9]+%", output.decode("utf-8"))
if not percents_re:
@ -317,7 +359,15 @@ class DelojzaBot:
return int(percents_re.group(0)[:-1])
# noinspection PyUnusedLocal
def download_ytdl(self, urls, out_path, date, message, audio=False, filetitle=None):
def download_ytdl(
self,
urls: List[str],
out_path: StrPath,
date: datetime,
message: telegram.Message,
audio: bool = False,
filetitle: Optional[str] = None,
):
ytdl = {
"noplaylist": True,
"restrictfilenames": True,
@ -352,7 +402,7 @@ class DelojzaBot:
else:
raise exc
for info in [ytdl.extract_info(url, download=False) for url in urls]:
filename = ytdl.prepare_filename(info)
filename = cast(str, ytdl.prepare_filename(info))
globbeds = glob(os.path.splitext(filename)[0] + ".*")
for globbed in globbeds:
if globbed.endswith("mp3"):
@ -362,7 +412,15 @@ class DelojzaBot:
filenames.append(dest)
return filenames
def download_raw(self, urls, out_path, date, message, audio=False, filetitle=None):
def download_raw(
self,
urls: List[str],
out_path: StrPath,
date: datetime,
message: telegram.Message,
audio: bool = False,
filetitle: Optional[str] = None,
):
filenames = []
for url in urls:
local_filename = os.path.join(
@ -409,7 +467,7 @@ class DelojzaBot:
return filenames
@staticmethod
def extract_hashtags(message):
def extract_hashtags(message: telegram.Message):
hashtags = list(
map(
message.parse_entity,
@ -429,7 +487,7 @@ class DelojzaBot:
hashtags[i] = "PRAS"
return hashtags
def _get_hashtags(self, message):
def _get_hashtags(self, message: telegram.Message):
hashtags = self.extract_hashtags(message)
if len(hashtags) == 0 and self.last_hashtags.get(message.chat.id) is not None:
user, ts, last_hashtags = self.last_hashtags[message.chat.id]
@ -437,7 +495,7 @@ class DelojzaBot:
hashtags = last_hashtags
return hashtags
def handle_text(self, message, hashtags):
def handle_text(self, message: telegram.Message, hashtags: List[str]):
if len(hashtags) == 0 or hashtags[0] not in ("TEXT", "TXT"):
return
@ -464,7 +522,14 @@ class DelojzaBot:
)
# noinspection PyBroadException
def handle(self, urls, message, hashtags, download_fn, filetitle=None):
def handle(
self,
urls: List[str],
message: telegram.Message,
hashtags: List[str],
download_fn: Any,
filetitle=None,
):
self.db.initialize()
try:
@ -605,7 +670,7 @@ class DelojzaBot:
else:
return False
def handle_urls(self, message, hashtags):
def handle_urls(self, message: telegram.Message, hashtags: List[str]):
urls = list(
map(
lambda e: message.parse_entity(e),
@ -631,7 +696,7 @@ class DelojzaBot:
return ytdl_res or raw_res
def tg_handle(self, bot, update):
def tg_handle(self, bot: telegram.Bot, update: telegram.Update):
self._log_msg(update)
hashtags = self._get_hashtags(update.message)
if hashtags:
@ -682,7 +747,7 @@ class DelojzaBot:
+ list(self.redirects.keys())
)
def tg_stats(self, _, update):
def tg_stats(self, _, update: telegram.Update):
self._log_msg(update)
self.db.initialize()
if update.message.chat.id not in self.db.get_protected_chats():
@ -750,7 +815,7 @@ class DelojzaBot:
result.append((directory, "NO FILE AT ALL..."))
return sorted(result, key=itemgetter(0))
def tg_orphan(self, _, update):
def tg_orphan(self, _, update: telegram.Update):
self._log_msg(update)
self.db.initialize()
if update.message.chat.id not in self.db.get_protected_chats():
@ -793,7 +858,7 @@ class DelojzaBot:
if len(tmp_reply) > 0:
update.message.reply_text(tmp_reply)
def tg_retag(self, _, update):
def tg_retag(self, _, update: telegram.Update):
self._log_msg(update)
if self.last_downloaded.get(update.message.chat.id) is not None:
files, hashtags, tumblr_ids = self.last_downloaded[update.message.chat.id]
@ -817,7 +882,7 @@ class DelojzaBot:
orig_artist, orig_title = self._get_tags(mp3)
title, artist = orig_artist, orig_title
self._tag_file(mp3, artist, title)
self._tag_file(mp3, artist, cast(str, title))
update.message.reply_text(
'Tagging "{}" as "{}" by "{}"!'.format(
mp3[len(out_dir) + 1 :], title, artist
@ -829,7 +894,7 @@ class DelojzaBot:
+ "???"
)
def tg_delete(self, _, update):
def tg_delete(self, _, update: telegram.Update):
self._log_msg(update)
if self.last_downloaded.get(update.message.chat.id) is not None:
files, hashtags, tumblr_ids = self.last_downloaded[update.message.chat.id]
@ -865,7 +930,7 @@ class DelojzaBot:
return
update.message.reply_text("Nothing to remove!")
def tg_protect(self, _, update):
def tg_protect(self, _, update: telegram.Update):
self._log_msg(update)
self.db.initialize()
@ -929,7 +994,7 @@ class DelojzaBot:
+ "???"
)
def tg_queue(self, _, update):
def tg_queue(self, _, update: telegram.Update):
if self.tumblr_client:
blog_info = self.tumblr_client.blog_info(self.tumblr_name)
update.message.reply_text(
@ -943,7 +1008,7 @@ class DelojzaBot:
)
# noinspection PyMethodMayBeStatic
def tg_version(self, _, update):
def tg_version(self, _, update: telegram.Update):
self._log_msg(update)
delojza_date = datetime.fromtimestamp(
os.path.getmtime(os.path.realpath(__file__))
@ -954,13 +1019,15 @@ class DelojzaBot:
)
)
def tg_start(self, _, update):
def tg_start(self, _, update: telegram.Update):
self._log_msg(update)
update.message.reply_text(
self.markov.make_sentence() if self.markov else "HELLO"
)
def tg_error(self, bot, update, error):
def tg_error(
self, bot: telegram.Bot, update: telegram.Update, error: telegram.TelegramError
):
self.logger.error(error)
if "Timed out" in str(error):
if update is not None:
@ -980,27 +1047,6 @@ class DelojzaBot:
self.updater.idle()
class MarkovBlabberer:
def __init__(self, filepath):
self.logger = logging.getLogger("markov")
self.filepath = filepath
with open(filepath) as f:
text = f.read()
self.markov = markovify.NewlineText(text.lower())
self.logger.info("Sentence of the day: " + self.make_sentence())
def make_sentence(self, tries=100):
return self.markov.make_sentence(tries=tries) or "???"
def add_to_corpus(self, text):
text = text.lower()
new_sentence = markovify.NewlineText(text)
self.markov = markovify.combine([self.markov, new_sentence])
with open(self.filepath, "a") as f:
f.write(text + "\n")
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
@ -1033,9 +1079,19 @@ if __name__ == "__main__":
markov = None
try:
redirects = config.items("redirects")
redirects: Optional[List[Tuple[str, str]]] = config.items("redirects")
except NoSectionError:
redirects = {}
redirects = None
try:
tumblr_keys = (
config.get("tumblr", "consumer_key"),
config.get("tumblr", "consumer_secret"),
config.get("tumblr", "oauth_key"),
config.get("tumblr", "oauth_secret"),
)
except (NoSectionError, KeyError):
tumblr_keys = None
delojza = DelojzaBot(
config.get("delojza", "tg_api_key"),
@ -1045,12 +1101,7 @@ if __name__ == "__main__":
protected_password=config.get("delojza", "protected_password", fallback=None),
acoustid_key=config.get("delojza", "acoustid_api_key", fallback=None),
tumblr_name=config.get("tumblr", "blog_name", fallback=None),
tumblr_keys=(
config.get("tumblr", "consumer_key", fallback=None),
config.get("tumblr", "consumer_secret", fallback=None),
config.get("tumblr", "oauth_key", fallback=None),
config.get("tumblr", "oauth_secret", fallback=None),
),
tumblr_keys=tumblr_keys,
markov=markov,
)
delojza.run_idle()