From 7b6e15cae63b7b7d0e58b6bffce2779eff2393dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Ml=C3=A1dek?= Date: Tue, 25 Jun 2019 15:49:25 +0200 Subject: [PATCH] add db, set protected chats/tags dynamically --- db_versions/00_initial.sql | 12 +++++ delojza.py | 101 +++++++++++++++++++++++++++++++++---- 2 files changed, 104 insertions(+), 9 deletions(-) create mode 100644 db_versions/00_initial.sql diff --git a/db_versions/00_initial.sql b/db_versions/00_initial.sql new file mode 100644 index 0000000..a1fc1ec --- /dev/null +++ b/db_versions/00_initial.sql @@ -0,0 +1,12 @@ +CREATE TABLE chats +( + id INTEGER PRIMARY KEY NOT NULL, + protected BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE TABLE tags +( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + tag TEXT UNIQUE NOT NULL, + protected BOOLEAN NOT NULL DEFAULT FALSE +) \ No newline at end of file diff --git a/delojza.py b/delojza.py index 0b9504e..2c58a8e 100755 --- a/delojza.py +++ b/delojza.py @@ -6,6 +6,7 @@ import os import pprint import re import shutil +import sqlite3 import subprocess import sys import tempfile @@ -44,12 +45,53 @@ def datestr(date): return date.strftime("%Y-%m-%d@%H%M") +class DelojzaDB: + def __init__(self, db_path): + self.db_path = db_path + self.db = None + + def initialize(self): + if self.db is None: + self.db = sqlite3.connect(self.db_path) + + def get_protected_tags(self): + results = self.db.execute("SELECT tag FROM tags WHERE protected == 1") + return [res[0] for res in results.fetchall()] + + def get_protected_chats(self): + results = self.db.execute("SELECT id FROM chats WHERE protected == 1") + return [res[0] for res in results.fetchall()] + + def get_chat(self, id): + return self.db.execute("SELECT id, protected FROM chats WHERE id == ?", (id,)).fetchone() + + def set_chat_protected(self, id, protected): + chat_in_db = self.get_chat(id) + if chat_in_db: + self.db.execute("UPDATE chats SET protected = ? WHERE id = ?", (protected, id)) + else: + self.db.execute("INSERT INTO chats (id, protected) VALUES (?, ?)", (id, protected)) + self.db.commit() + + def get_tag(self, tag): + return self.db.execute("SELECT id, tag, protected FROM tags WHERE tag == ?", (tag,)).fetchone() + + def set_tag_protected(self, tag, protected): + tag_in_db = self.get_tag(tag) + if tag_in_db: + self.db.execute("UPDATE tags SET protected = ? WHERE tag = ?", (protected, tag)) + else: + self.db.execute("INSERT INTO tags (tag, protected) VALUES (?, ?)", (tag, protected)) + self.db.commit() + + class DelojzaBot: - def __init__(self, tg_api_key, out_dir, tmp_dir=None, - protected_chats=None, protected_tags=None, + def __init__(self, tg_api_key, out_dir, tmp_dir=None, db_path=None, protected_password=None, acoustid_key=None, tumblr_name=None, tumblr_keys=None, markov=None): self.logger = logging.getLogger("delojza") + self.db = DelojzaDB(db_path or os.path.join(os.path.dirname(os.path.realpath(__file__)), "delojza.db")) + self.out_dir = os.path.abspath(out_dir) self.logger.debug('OUT_DIR: ' + out_dir) self.tmp_dir = tmp_dir if tmp_dir else tempfile.gettempdir() @@ -66,6 +108,7 @@ class DelojzaBot: dp.add_handler(CommandHandler("orphans_full", self.tg_orphan_full)) dp.add_handler(CommandHandler("retag", self.tg_retag)) dp.add_handler(CommandHandler("delete", self.tg_delete)) + dp.add_handler(CommandHandler("protect", self.tg_protect)) dp.add_handler(CommandHandler("version", self.tg_version)) dp.add_handler(MessageHandler(None, self.tg_handle)) @@ -77,9 +120,7 @@ class DelojzaBot: else: self.tumblr_client = None - self.protected_chats = protected_chats or [] - self.protected_tags = protected_tags or [] - + self.protected_password = protected_password self.last_downloaded = {} self.last_hashtags = {} @@ -299,13 +340,15 @@ class DelojzaBot: # noinspection PyBroadException def handle(self, urls, message, hashtags, download_fn, filetitle=None): + self.db.initialize() + try: if len(hashtags) == 0: self.logger.info("Ignoring %s due to no hashtag present..." % urls) return False - if any(hashtag in self.protected_tags for hashtag in hashtags): - if message.chat.title not in self.protected_chats: + if any(hashtag in self.db.get_protected_tags() for hashtag in hashtags): + if message.chat.id not in self.db.get_protected_chats(): self.logger.info("Redirecting {} in chat {} due to protected hashtags: {}..." .format(urls, message.chat.title, hashtags)) hashtags.insert(0, "PUBLIC") @@ -529,6 +572,47 @@ class DelojzaBot: return update.message.reply_text("Nothing to remove!") + def tg_protect(self, _, update): + self.db.initialize() + + msg_split = update.message.text.split(" ") + if len(msg_split) != 3: + update.message.reply_text((self.markov.make_sentence() if self.markov and random() > .7 else "") + "???") + return + + chat_in_db = self.db.get_chat(update.message.chat.id) + + cmd = msg_split[1] + if cmd == 'tag': + if chat_in_db and chat_in_db[1]: + tag = msg_split[2].upper() + tag_in_db = self.db.get_tag(tag) + if tag_in_db: + _, _, protected = tag_in_db + end_protected = not protected + else: + end_protected = True + + self.db.set_tag_protected(tag, end_protected) + update.message.reply_text(f"got it, will {'NOT ' if not end_protected else ''}protect tag {tag}!") + else: + update.message.reply_text((self.markov.make_sentence() if self.markov and random() > .7 else "hublubl")) + elif cmd == 'chat': + password = msg_split[2] + if password == self.protected_password: + if chat_in_db: + _, protected = chat_in_db + end_protected = not protected + else: + end_protected = True + + self.db.set_chat_protected(update.message.chat.id, end_protected) + update.message.reply_text(f"got it, will {'NOT ' if not end_protected else ''}protect this chat!") + else: + update.message.reply_text((self.markov.make_sentence() if self.markov and random() > .7 else "hublubl")) + else: + update.message.reply_text((self.markov.make_sentence() if self.markov and random() > .7 else "") + "???") + # noinspection PyMethodMayBeStatic def tg_version(self, _, update): delojza_date = datetime.fromtimestamp(os.path.getmtime(os.path.realpath(__file__))) \ @@ -602,8 +686,7 @@ if __name__ == '__main__': delojza = DelojzaBot(config.get('delojza', 'tg_api_key'), config.get('delojza', 'OUT_DIR', fallback=os.path.join(_DIR_, "out")), tmp_dir=config.get('delojza', 'tmp_dir', fallback=tempfile.gettempdir()), - protected_chats=config.get('delojza', 'protected_chats', fallback='').split(";"), - protected_tags=config.get('delojza', 'protected_tags', fallback='').split(";"), + protected_password=config.get('delojza', 'protected_password', fallback=None), acoustid_key=config.get('delojza', 'acoustid_api_key', fallback=None), tumblr_name=config.get('tumblr', 'blog_name', fallback=None), tumblr_keys=(config.get('tumblr', 'consumer_key', fallback=None),