upend/tools/fromksx/fromksx.py

135 lines
4.5 KiB
Python

import hashlib
import logging
from dataclasses import dataclass
from operator import add
import click
import colorama
import psycopg2
from tqdm import tqdm
from upend import UpEnd
class LogFormatter(logging.Formatter):
    """Log formatter that picks a (colored) format string per severity level.

    ``FORMATS`` maps each standard logging level to ``format_str`` wrapped in
    colorama escape sequences; INFO is left uncolored.
    """

    format_str = "[%(asctime)s] %(levelname)s - %(message)s"

    FORMATS = {
        logging.DEBUG: colorama.Fore.LIGHTBLACK_EX + format_str + colorama.Fore.RESET,
        logging.INFO: format_str,
        logging.WARNING: colorama.Fore.YELLOW + format_str + colorama.Fore.RESET,
        logging.ERROR: colorama.Fore.RED + format_str + colorama.Fore.RESET,
        logging.CRITICAL: colorama.Fore.RED
        + colorama.Style.BRIGHT
        + format_str
        + colorama.Style.RESET_ALL
        + colorama.Fore.RESET,
    }

    def format(self, record):
        """Return *record* rendered with the format registered for its level.

        Levels that have no entry in ``FORMATS`` (custom or intermediate
        levels) fall back to the plain ``format_str``. Without the fallback,
        ``logging.Formatter(None)`` would use the bare ``%(message)s`` default
        and drop the timestamp and level name.
        """
        log_fmt = self.FORMATS.get(record.levelno, self.format_str)
        return logging.Formatter(log_fmt).format(record)
@dataclass
class KSXTrackFile:
    """One annotated track file row as read from the KSX database.

    Instances are built positionally from query rows (``KSXTrackFile(*row)``
    in ``main``), so the field order below must stay in step with the column
    order of that ``SELECT``.
    """

    # Identification of the file; ``sha256sum`` is what gets matched against
    # the SHA256 values stored in UpEnd.
    file: str
    sha256sum: str

    # Mood annotations; each one is inserted into UpEnd under a ``KSX_*`` key.
    energy: int
    seriousness: int
    tint: int
    materials: int
@click.command()
@click.option("--db-name", required=True)
@click.option("--db-user", required=True)
@click.option("--db-password", required=True)
@click.option("--db-host", default="localhost")
@click.option("--db-port", default=5432, type=int)
def main(db_name, db_user, db_password, db_host, db_port):
    """Load KSX database dump into UpEnd."""
    # NOTE: the docstring above doubles as the click ``--help`` text, so the
    # longer description lives in comments instead.
    #
    # Steps:
    #   1. Read every annotated track file (file, sha256sum and the four mood
    #      values) from PostgreSQL.
    #   2. Make sure every BLOB entity in UpEnd has a SHA256 attribute,
    #      computing the missing ones from the raw file contents.
    #   3. For every SHA256 present on both sides, attach the KSX mood values
    #      to the matching UpEnd entity.
    #
    # The db_* parameters are passed straight to ``psycopg2.connect``.
    logger = logging.getLogger("ksx2upend")
    logger.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(LogFormatter())
    logger.addHandler(ch)

    logger.debug("Connecting to PostgreSQL...")
    connection = psycopg2.connect(
        database=db_name,
        user=db_user,
        password=db_password,
        host=db_host,
        port=db_port,
    )
    try:
        logger.debug("Connecting to UpEnd...")
        upend = UpEnd()

        # The cursor context manager closes the cursor; the connection itself
        # must be closed explicitly (psycopg2's connection ``with`` only ends
        # the transaction), hence the surrounding try/finally.
        with connection.cursor() as cur:
            cur.execute(
                "SELECT file, sha256sum, energy, seriousness, tint, materials "
                "FROM ksx_radio_trackfile "
                "INNER JOIN ksx_radio_moodsregular ON ksx_radio_trackfile.track_id = ksx_radio_moodsregular.track_id"
            )
            trackfiles = [KSXTrackFile(*row) for row in cur.fetchall()]
    finally:
        # Everything needed from PostgreSQL is in memory now; do not keep the
        # connection open during the (potentially long) hashing phase.
        connection.close()

    logger.info(f"Got {len(trackfiles)} (annotated) trackfiles from database...")

    # TODO: get_invariant() or somesuch?
    # Look up the entity of the BLOB type, then every entity that "IS" it.
    blob_addr = list(upend.query((None, "TYPE", 'J"BLOB"')).values())[0]["entity"]
    all_files = upend.query((None, "IS", f"O{blob_addr}")).values()
    hashed_files = upend.query((None, "SHA256", None)).values()
    logger.info(
        f"Got {len(all_files)} files from UpEnd ({len(hashed_files)} of which are hashed)..."
    )

    if len(hashed_files) < len(all_files):
        logger.info("Computing SHA256 hashes for UpEnd files...")
        # A set keeps the membership test O(1) per file instead of scanning a
        # list for every candidate.
        # NOTE(review): assumes entity addresses are hashable (strings) --
        # they are interpolated into a query string above; confirm.
        hashed_entities = {entry["entity"] for entry in hashed_files}
        unhashed_files = [
            file for file in all_files if file["entity"] not in hashed_entities
        ]
        for entry in tqdm(unhashed_files):
            sha256_hash = hashlib.sha256()
            for chunk in upend.get_raw(entry["entity"]):
                sha256_hash.update(chunk)
            upend.insert((entry["entity"], "SHA256", sha256_hash.hexdigest()))
        # Re-read so the freshly inserted hashes take part in the matching.
        hashed_files = upend.query((None, "SHA256", None)).values()

    # Index both sides by SHA256 and keep the digests known to both.
    sha256_trackfiles = {tf.sha256sum: tf for tf in trackfiles}
    sha256_entities = {entry["value"]["c"]: entry["entity"] for entry in hashed_files}
    tf_and_ue = [digest for digest in sha256_trackfiles if digest in sha256_entities]
    logger.info(
        f"Out of {len(trackfiles)} trackfiles, and out of {len(hashed_files)} files in UpEnd, {len(tf_and_ue)} are present in both."
    )

    logger.info("Inserting types...")
    ksx_type_result = upend.insert((None, "TYPE", "KSX_TRACK_MOODS"))
    ksx_type_addr = list(ksx_type_result.values())[0]["entity"]
    upend.insert((ksx_type_addr, "TYPE_REQUIRES", "KSX_ENERGY"))
    upend.insert((ksx_type_addr, "TYPE_REQUIRES", "KSX_SERIOUSNESS"))
    upend.insert((ksx_type_addr, "TYPE_REQUIRES", "KSX_TINT"))
    upend.insert((ksx_type_addr, "TYPE_REQUIRES", "KSX_MATERIALS"))

    logger.info("Inserting mood data...")
    # ``digest`` rather than ``sum`` so the builtin is not shadowed.
    for digest in tqdm(tf_and_ue):
        tf = sha256_trackfiles[digest]
        address = sha256_entities[digest]
        upend.insert((address, "IS", ksx_type_addr), value_type="Address")
        upend.insert((address, "KSX_ENERGY", tf.energy))
        upend.insert((address, "KSX_SERIOUSNESS", tf.seriousness))
        upend.insert((address, "KSX_TINT", tf.tint))
        upend.insert((address, "KSX_MATERIALS", tf.materials))
# Script entry point: main() is a click command, so calling it without
# arguments makes click parse the options from the command line.
if __name__ == "__main__":
    main()