add db, set protected chats/tags dynamically
This commit is contained in:
parent
bfcecc7cf1
commit
7b6e15cae6
2 changed files with 104 additions and 9 deletions
12
db_versions/00_initial.sql
Normal file
12
db_versions/00_initial.sql
Normal file
|
@ -0,0 +1,12 @@
|
|||
CREATE TABLE chats
|
||||
(
|
||||
id INTEGER PRIMARY KEY NOT NULL,
|
||||
protected BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE TABLE tags
|
||||
(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||
tag TEXT UNIQUE NOT NULL,
|
||||
protected BOOLEAN NOT NULL DEFAULT FALSE
|
||||
)
|
101
delojza.py
101
delojza.py
|
@ -6,6 +6,7 @@ import os
|
|||
import pprint
|
||||
import re
|
||||
import shutil
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
@ -44,12 +45,53 @@ def datestr(date):
|
|||
return date.strftime("%Y-%m-%d@%H%M")
|
||||
|
||||
|
||||
class DelojzaDB:
|
||||
def __init__(self, db_path):
|
||||
self.db_path = db_path
|
||||
self.db = None
|
||||
|
||||
def initialize(self):
|
||||
if self.db is None:
|
||||
self.db = sqlite3.connect(self.db_path)
|
||||
|
||||
def get_protected_tags(self):
|
||||
results = self.db.execute("SELECT tag FROM tags WHERE protected == 1")
|
||||
return [res[0] for res in results.fetchall()]
|
||||
|
||||
def get_protected_chats(self):
|
||||
results = self.db.execute("SELECT id FROM chats WHERE protected == 1")
|
||||
return [res[0] for res in results.fetchall()]
|
||||
|
||||
def get_chat(self, id):
|
||||
return self.db.execute("SELECT id, protected FROM chats WHERE id == ?", (id,)).fetchone()
|
||||
|
||||
def set_chat_protected(self, id, protected):
|
||||
chat_in_db = self.get_chat(id)
|
||||
if chat_in_db:
|
||||
self.db.execute("UPDATE chats SET protected = ? WHERE id = ?", (protected, id))
|
||||
else:
|
||||
self.db.execute("INSERT INTO chats (id, protected) VALUES (?, ?)", (id, protected))
|
||||
self.db.commit()
|
||||
|
||||
def get_tag(self, tag):
|
||||
return self.db.execute("SELECT id, tag, protected FROM tags WHERE tag == ?", (tag,)).fetchone()
|
||||
|
||||
def set_tag_protected(self, tag, protected):
|
||||
tag_in_db = self.get_tag(tag)
|
||||
if tag_in_db:
|
||||
self.db.execute("UPDATE tags SET protected = ? WHERE tag = ?", (protected, tag))
|
||||
else:
|
||||
self.db.execute("INSERT INTO tags (tag, protected) VALUES (?, ?)", (tag, protected))
|
||||
self.db.commit()
|
||||
|
||||
|
||||
class DelojzaBot:
|
||||
def __init__(self, tg_api_key, out_dir, tmp_dir=None,
|
||||
protected_chats=None, protected_tags=None,
|
||||
def __init__(self, tg_api_key, out_dir, tmp_dir=None, db_path=None, protected_password=None,
|
||||
acoustid_key=None, tumblr_name=None, tumblr_keys=None, markov=None):
|
||||
self.logger = logging.getLogger("delojza")
|
||||
|
||||
self.db = DelojzaDB(db_path or os.path.join(os.path.dirname(os.path.realpath(__file__)), "delojza.db"))
|
||||
|
||||
self.out_dir = os.path.abspath(out_dir)
|
||||
self.logger.debug('OUT_DIR: ' + out_dir)
|
||||
self.tmp_dir = tmp_dir if tmp_dir else tempfile.gettempdir()
|
||||
|
@ -66,6 +108,7 @@ class DelojzaBot:
|
|||
dp.add_handler(CommandHandler("orphans_full", self.tg_orphan_full))
|
||||
dp.add_handler(CommandHandler("retag", self.tg_retag))
|
||||
dp.add_handler(CommandHandler("delete", self.tg_delete))
|
||||
dp.add_handler(CommandHandler("protect", self.tg_protect))
|
||||
dp.add_handler(CommandHandler("version", self.tg_version))
|
||||
dp.add_handler(MessageHandler(None, self.tg_handle))
|
||||
|
||||
|
@ -77,9 +120,7 @@ class DelojzaBot:
|
|||
else:
|
||||
self.tumblr_client = None
|
||||
|
||||
self.protected_chats = protected_chats or []
|
||||
self.protected_tags = protected_tags or []
|
||||
|
||||
self.protected_password = protected_password
|
||||
self.last_downloaded = {}
|
||||
self.last_hashtags = {}
|
||||
|
||||
|
@ -299,13 +340,15 @@ class DelojzaBot:
|
|||
|
||||
# noinspection PyBroadException
|
||||
def handle(self, urls, message, hashtags, download_fn, filetitle=None):
|
||||
self.db.initialize()
|
||||
|
||||
try:
|
||||
if len(hashtags) == 0:
|
||||
self.logger.info("Ignoring %s due to no hashtag present..." % urls)
|
||||
return False
|
||||
|
||||
if any(hashtag in self.protected_tags for hashtag in hashtags):
|
||||
if message.chat.title not in self.protected_chats:
|
||||
if any(hashtag in self.db.get_protected_tags() for hashtag in hashtags):
|
||||
if message.chat.id not in self.db.get_protected_chats():
|
||||
self.logger.info("Redirecting {} in chat {} due to protected hashtags: {}..."
|
||||
.format(urls, message.chat.title, hashtags))
|
||||
hashtags.insert(0, "PUBLIC")
|
||||
|
@ -529,6 +572,47 @@ class DelojzaBot:
|
|||
return
|
||||
update.message.reply_text("Nothing to remove!")
|
||||
|
||||
def tg_protect(self, _, update):
|
||||
self.db.initialize()
|
||||
|
||||
msg_split = update.message.text.split(" ")
|
||||
if len(msg_split) != 3:
|
||||
update.message.reply_text((self.markov.make_sentence() if self.markov and random() > .7 else "") + "???")
|
||||
return
|
||||
|
||||
chat_in_db = self.db.get_chat(update.message.chat.id)
|
||||
|
||||
cmd = msg_split[1]
|
||||
if cmd == 'tag':
|
||||
if chat_in_db and chat_in_db[1]:
|
||||
tag = msg_split[2].upper()
|
||||
tag_in_db = self.db.get_tag(tag)
|
||||
if tag_in_db:
|
||||
_, _, protected = tag_in_db
|
||||
end_protected = not protected
|
||||
else:
|
||||
end_protected = True
|
||||
|
||||
self.db.set_tag_protected(tag, end_protected)
|
||||
update.message.reply_text(f"got it, will {'NOT ' if not end_protected else ''}protect tag {tag}!")
|
||||
else:
|
||||
update.message.reply_text((self.markov.make_sentence() if self.markov and random() > .7 else "hublubl"))
|
||||
elif cmd == 'chat':
|
||||
password = msg_split[2]
|
||||
if password == self.protected_password:
|
||||
if chat_in_db:
|
||||
_, protected = chat_in_db
|
||||
end_protected = not protected
|
||||
else:
|
||||
end_protected = True
|
||||
|
||||
self.db.set_chat_protected(update.message.chat.id, end_protected)
|
||||
update.message.reply_text(f"got it, will {'NOT ' if not end_protected else ''}protect this chat!")
|
||||
else:
|
||||
update.message.reply_text((self.markov.make_sentence() if self.markov and random() > .7 else "hublubl"))
|
||||
else:
|
||||
update.message.reply_text((self.markov.make_sentence() if self.markov and random() > .7 else "") + "???")
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
def tg_version(self, _, update):
|
||||
delojza_date = datetime.fromtimestamp(os.path.getmtime(os.path.realpath(__file__))) \
|
||||
|
@ -602,8 +686,7 @@ if __name__ == '__main__':
|
|||
delojza = DelojzaBot(config.get('delojza', 'tg_api_key'),
|
||||
config.get('delojza', 'OUT_DIR', fallback=os.path.join(_DIR_, "out")),
|
||||
tmp_dir=config.get('delojza', 'tmp_dir', fallback=tempfile.gettempdir()),
|
||||
protected_chats=config.get('delojza', 'protected_chats', fallback='').split(";"),
|
||||
protected_tags=config.get('delojza', 'protected_tags', fallback='').split(";"),
|
||||
protected_password=config.get('delojza', 'protected_password', fallback=None),
|
||||
acoustid_key=config.get('delojza', 'acoustid_api_key', fallback=None),
|
||||
tumblr_name=config.get('tumblr', 'blog_name', fallback=None),
|
||||
tumblr_keys=(config.get('tumblr', 'consumer_key', fallback=None),
|
||||
|
|
Loading…
Reference in a new issue