#!/usr/bin/env python3
import argparse
import enum
import os
import re
import shlex
import stat
import sys
import nihtest.Command
import nihtest.CompareArrays

# Header lines that introduce each table section of the textual database dump.
# read_ckmamedb() compares dump lines against these to decide which table the
# following rows belong to, and CkmameDB.dump() emits them before each section.
table_archive = ">>> table archive (archive_id, name, mtime, size, file_type)"
table_detector = ">>> table detector (detector_id, name, version)"
table_file = ">>> table file (archive_id, file_idx, name, mtime, status, size, crc, md5, sha1, sha256, detector_id)"


class Table(enum.Enum):
    """Identifies which table section of a database dump is being parsed."""

    ARCHIVE = 1
    DETECTOR = 2
    FILE = 3


# Matches one "key=value" attribute word: group 1 is the key (one or more
# characters other than "="), group 2 is everything after the first "=".
re_attribute = re.compile("^([^=]+)=(.*)")

# Maps a textual status name to the numeric code stored in File.status.
status_names = {
    "ok": 0,
    "baddump": 1,
    "nodump": 2
}

class Index:
    """Two-way mapping between names and integer indices.

    Unknown names can be assigned the next free index on demand; the next
    free index is always larger than every index registered so far.
    """

    def __init__(self, start_index = 0):
        self.next_index = start_index
        self.indices = {}  # name -> index
        self.names = {}    # index -> name

    def add(self, name, index):
        """Register `name` under `index`, advancing the next free index if needed."""
        self.next_index = max(self.next_index, index + 1)
        self.names[index] = name
        self.indices[name] = index

    def add_files(self, files):
        """Register every file's name under its (integer-converted) index."""
        for entry in files:
            self.add(entry.name, int(entry.index))

    def add_archives(self, archives):
        """Register every archive's name under its id."""
        for entry in archives:
            self.add(entry.name, entry.id)

    def get_index(self, name):
        """Return the index of `name`, allocating the next free one if unknown."""
        try:
            return self.indices[name]
        except KeyError:
            allocated = self.next_index
            self.add(name, allocated)
            return allocated

    def get_existing_index(self, name):
        """Return the index of `name`, or None if it was never registered."""
        return self.indices.get(name)

    def get_existing_name(self, index):
        """Return the name registered under `index`, or None if there is none."""
        return self.names.get(index)

class File:
    """One row of the file table: a file (or detector hash) inside an archive.

    Field values are stored as given; they may be strings, integers, or
    placeholder strings such as "<null>" or "<???>".
    """

    def __init__(self, name, index=0, mtime="<null>", status=0, size="-1", crc="<null>", md5="<null>", sha1="<null>", sha256="<null>", detector_id=0):
        self.name = name
        self.index = index
        self.status = status
        self.size = size
        self.crc = crc
        self.md5 = md5
        self.sha1 = sha1
        self.sha256 = sha256
        self.mtime = mtime
        self.detector_id = detector_id

    @staticmethod
    def _numeric_key(value):
        """Return a sort key for a field that normally holds an integer.

        Values convertible with int() compare numerically and sort before
        placeholders (e.g. "<null>", "<???>"), which compare by their text.
        """
        try:
            return (0, int(value), "")
        except (TypeError, ValueError):
            return (1, 0, str(value))

    @staticmethod
    def _mixed_key(value):
        """Return a sort key for a field that may hold either an int or a string.

        Integers compare numerically among themselves and strings compare
        lexicographically among themselves; integers sort before strings so
        that mixed comparisons do not raise TypeError.
        """
        if isinstance(value, int):
            return (0, value, "")
        return (1, 0, str(value))

    def _sort_key(self):
        """Return the tuple that defines the ordering used by __lt__."""
        return (
            self._numeric_key(self.index),
            self.name,
            self._numeric_key(self.mtime),
            self._mixed_key(self.status),
            self._numeric_key(self.size),
            self._mixed_key(self.crc),
            self._mixed_key(self.md5),
            self._mixed_key(self.sha1),
            self._mixed_key(self.sha256),
            self._numeric_key(self.detector_id),
        )

    def __eq__(self, other):
        """Compare all fields except mtime; numeric fields compare by integer value.

        NOTE(review): mtime takes part in ordering but not in equality;
        presumably deliberate, confirm.
        """
        if not isinstance(other, File):
            return NotImplemented
        return (self._numeric_key(self.index) == self._numeric_key(other.index)
                and self.name == other.name
                and self.status == other.status
                and self._numeric_key(self.size) == self._numeric_key(other.size)
                and self.crc == other.crc
                and self.md5 == other.md5
                and self.sha1 == other.sha1
                and self.sha256 == other.sha256
                and self._numeric_key(self.detector_id) == self._numeric_key(other.detector_id))

    def __lt__(self, other):
        """Order by index, name, mtime, status, size, crc, md5, sha1, sha256, detector_id.

        Uses type-safe keys so that placeholder values and mixed int/str
        fields (as found on detector hash rows) can be sorted without raising.
        """
        if not isinstance(other, File):
            return NotImplemented
        return self._sort_key() < other._sort_key()

    def row(self, archive_id):
        """Return this file formatted as a "|"-separated file table row."""
        return f"{archive_id}|{self.index}|{self.name}|{self.mtime}|{self.status}|{self.size}|{self.crc}|{self.md5}|{self.sha1}|{self.sha256}|{self.detector_id}"


class Archive:
    """An archive with its files and detector hashes.

    Files added through add_file() get their index from a per-archive Index.
    """

    def __init__(self, id, name, mtime="<null>", size=0, file_type=0):
        self.id = int(id)
        self.name = name
        self.mtime = mtime
        self.size = size
        self.files = []
        self.detector_hashes = []
        self.file_type = file_type
        self.file_index = Index()

    def add_file(self, file):
        """Assign `file` its index within this archive and append it."""
        file.index = self.file_index.get_index(file.name)
        # if one of the files has no CRC, assume the archive is a disk
        has_crc = file.crc != "<null>"
        if not has_crc:
            self.file_type = 1
        self.files.append(file)

    def add_detector_hash(self, file):
        """Append a detector hash entry for this archive."""
        self.detector_hashes.append(file)

    def archive_row(self):
        """Return this archive formatted as a "|"-separated archive table row."""
        return f"{self.id}|{self.name}|{self.mtime}|{self.size}|{self.file_type}"

    def file_rows(self):
        """Return an iterator over the sorted rows of all files and detector hashes."""
        return (entry.row(self.id) for entry in sorted(self.files + self.detector_hashes))

    def get_detector_hash(self, file_index, detector_id):
        """Return the first entry in `files` with the given index and detector id, or None."""
        return next(
            (
                entry
                for entry in self.files
                if int(entry.index) == int(file_index) and int(entry.detector_id) == int(detector_id)
            ),
            None,
        )

class Detector:
    """A detector identified by numeric id, name and version."""

    def __init__(self, id, name, version):
        self.id = int(id)
        self.name = name
        self.version = version

    def row(self):
        """Return this detector formatted as a "|"-separated detector table row."""
        columns = (self.id, self.name, self.version)
        return "|".join(str(column) for column in columns)


class CkmameDB:
    """In-memory model of a database: archives and detectors plus their id indices."""

    def __init__(self):
        self.archives = {}              # archive name -> Archive
        self.archive_index = Index()    # archive name <-> archive id
        self.detectors = {}             # "name|version" -> Detector
        self.detector_index = Index(1)  # detector ids are allocated from 1

    def add_archive(self, archive):
        """Store `archive` under its name and register its id."""
        self.archives[archive.name] = archive
        self.archive_index.add(archive.name, archive.id)

    def add_detector(self, detector):
        """Store `detector` under its "name|version" key and register its id."""
        key = detector.name + "|" + detector.version
        self.detectors[key] = detector
        self.detector_index.add(key, detector.id)

    def get_archive_by_id(self, id):
        """Return the archive registered under `id`, or None.

        None is also returned when the id is known to the index but no archive
        object was ever added for it; get_archive_id() can allocate such ids.
        """
        name = self.archive_index.get_existing_name(id)
        # Compare against None explicitly so an empty name is still looked up.
        if name is None:
            return None
        return self.archives.get(name)

    def get_archive_id(self, name):
        """Return the id for archive `name`, allocating a new one if unknown."""
        return self.archive_index.get_index(name)

    def get_detector_id(self, key):
        """Return the id for detector `key`, allocating a new one if unknown."""
        return self.detector_index.get_index(key)

    def dump(self):
        """Return the database as a list of dump lines, each table sorted by id."""
        lines = []
        archives = sorted(self.archives.values(), key=lambda a: a.id)
        detectors = sorted(self.detectors.values(), key=lambda d: d.id)

        lines.append(table_archive)
        lines.extend(archive.archive_row() for archive in archives)

        lines.append(table_detector)
        lines.extend(detector.row() for detector in detectors)

        lines.append(table_file)
        for archive in archives:
            lines.extend(archive.file_rows())

        return lines


class CkmameDBComparator:
    def __init__(self, filename, unzipped):
        """Run the whole comparison and terminate the process.

        Reads directives from stdin, dumps the database `filename`, builds the
        expected database from the directory containing `filename`, and
        compares the two dumps. Exits with status 1 after printing the
        differences, or with status 0 if there are none.
        """
        self.directory = os.path.dirname(filename)
        self.unzipped = unzipped
        self.got_db = None
        self.got_dump = []
        self.expected_db = None
        self.hash_exceptions = {}
        self.detector_hashes = {}
        self.ignored_archives = {}
        self.status_overrides = {}

        # Order matters: directives first, then the actual db, then the expected one.
        self.process_stdin()
        self.read_ckmamedb(filename)
        self.read_directory()

        expected_dump = self.expected_db.dump()
        comparison = nihtest.CompareArrays.CompareArrays(expected_dump, self.got_dump)
        diffs = comparison.get_diff()
        if len(diffs) == 0:
            sys.exit(0)
        print(os.linesep.join(diffs))
        sys.exit(1)

    def process_stdin(self):
        for line in sys.stdin.readlines():
            words = shlex.split(line)
            # TODO: backslash escapes
            if words[0] == "hashes":
                self.add_hash_exception(words[1:])
            elif words[0] == "detector-hashes":
                self.add_detector_hashes(words[1:])
            elif words[0] == "ignore":
                self.ignore_archive(words[1:])
            elif words[0] == "status":
                self.add_status_override(words[1:])
            else:
                print(f"unknown directive '{words[0]}'", file=sys.stderr)
                sys.exit(2)

    def add_hash_exception(self, arguments):
        """Record which hash types are expected for one file of an archive.

        `arguments` is [archive, file, specs], where specs is a comma-separated
        list applied left to right, starting from all hashes:
        "all", "none" and "cheap" replace the whole set; "crc", "md5", "sha1"
        and "sha256" add a hash; the same names prefixed with "-" remove one.
        Removing a hash that is not in the set is a no-op. Invalid specs are
        reported and otherwise ignored.
        """
        archive = arguments[0]
        # TODO: strip .zip for variant unzipped
        file = arguments[1]
        all_hashes = ("crc", "md5", "sha1", "sha256")
        hashes = set(all_hashes)
        for spec in arguments[2].split(","):
            if spec == "cheap":
                hashes = {"crc"} # TODO: {} for unzipped
            elif spec == "none":
                # Must be a real set (not a dict) so later specs can add to it.
                hashes = set()
            elif spec == "all":
                hashes = set(all_hashes)
            elif spec in all_hashes:
                hashes.add(spec)
            elif spec.startswith("-") and spec[1:] in all_hashes:
                hashes.discard(spec[1:])
            else:
                print(f"invalid hash spec '{spec}'")
        self.hash_exceptions[f"{archive}/{file}"] = hashes

    def add_detector_hashes(self, arguments):
        """Record that a detector applies to one file of an archive.

        `arguments` is [detector name, detector version, archive, file]; the
        "name|version" key is appended to the list kept for "archive/file".
        """
        detector_key = arguments[0] + "|" + arguments[1]
        archive = arguments[2]
        # TODO: strip .zip for variant unzipped
        file = arguments[3]
        self.detector_hashes.setdefault(f"{archive}/{file}", []).append(detector_key)

    def add_status_override(self, arguments):
        """Record a status override for one file of an archive.

        `arguments` is [archive, file, status name]. Raises RuntimeError if the
        status name is not a key of `status_names`.
        """
        archive = arguments[0]
        file = arguments[1]
        status_name = arguments[2]
        if status_name not in status_names:
            raise RuntimeError(f"unknown status name {status_name}")
        self.status_overrides[f"{archive}/{file}"] = status_names[status_name]

    def ignore_archive(self, arguments):
        """Mark the archive named by the first argument as ignored."""
        archive_name = arguments[0]
        self.ignored_archives[archive_name] = True

    def read_ckmamedb(self, filename):
        """Dump the database `filename` with dbdump and parse the output.

        The parsed database is stored in `self.got_db` and the raw output
        lines in `self.got_dump`. A non-zero exit code of dbdump is reported
        but does not stop processing.
        """
        command = nihtest.Command.Command("dbdump", [filename])
        command.run()
        if command.exit_code != 0:
            print("dbdump failed:")
            print(os.linesep.join(command.stderr))
            # TODO: error if command.exit_code != 0

        sections = {
            table_archive: Table.ARCHIVE,
            table_detector: Table.DETECTOR,
            table_file: Table.FILE,
        }
        row_handlers = {
            Table.ARCHIVE: self.process_archive,
            Table.DETECTOR: self.process_detector,
            Table.FILE: self.process_file,
        }

        db = CkmameDB()
        current = None

        for line in command.stdout:
            if line.startswith(">>> "):
                # Unknown headers reset the section, so their rows are skipped.
                # TODO: error: unknown table
                current = sections.get(line)
                continue
            if current is None:
                # TODO: error: line outside table
                continue
            row_handlers[current](db, line.split("|"))

        self.got_db = db
        self.got_dump = command.stdout

    def process_archive(self, db, fields):
        """Add the archive described by one archive table row to `db`."""
        entry = Archive(id=fields[0], name=fields[1], mtime=fields[2], size=fields[3], file_type=fields[4])
        db.add_archive(entry)

    def process_detector(self, db, fields):
        """Add the detector described by one detector table row to `db`."""
        entry = Detector(id=fields[0], name=fields[1], version=fields[2])
        db.add_detector(entry)

    def process_file(self, db, fields):
        archive = db.get_archive_by_id(int(fields[0]))
        if archive is None:
            # TODO error
            0 == 0
        else:
            archive.files.append(File(fields[2], fields[1], fields[3], fields[4], fields[5], fields[6], fields[7], fields[8], fields[9], fields[10]))

    def read_directory(self):
        """Build the expected database from the directory tree on disk.

        Runs mkmamedb on `self.directory` (with "--roms-unzipped" if
        `self.unzipped` is set) and parses its output: each line is a name
        followed by space-separated "key=value" attributes. Lines with
        type=dir start a new archive, lines with type=file add a file to the
        current archive. Archive, file and detector ids are taken from
        `self.got_db` where known, so that the expected dump uses the same
        numbering. The result is stored in `self.expected_db`.

        Raises RuntimeError if an attribute is not of the form "key=value".
        """
        arguments = ["-o", "/dev/stdout", "--runtest"]
        if self.unzipped:
            arguments.append("--roms-unzipped")
        arguments.append(self.directory)
        command = nihtest.Command.Command("mkmamedb", arguments)
        command.run()
        if command.exit_code != 0:
            print("mkmamedb failed:")
            print(os.linesep.join(command.stderr))
            # TODO: error if command.exit_code != 0

        db = CkmameDB()
        db.archive_index.add_archives(self.got_db.archives.values())
        archive = None
        archive_name = None
        got_archive = None
        prefix = ""

        detector_ids = {}

        for line in command.stdout:
            # startswith() is safe for empty lines, which are then dropped
            # by the word-count check below.
            if line.startswith("#"):
                continue
            words = line.split(" ")
            if len(words) < 2:
                # TODO: invalid line
                continue

            name = words[0]
            if name == ".":
                continue
            if name[0:2] == "./":
                name = name[2:]
            name = destrsvis(name)

            attributes = {}
            for attribute in words[1:]:
                match = re_attribute.match(attribute)
                if match is None:
                    raise RuntimeError(f"invalid attribute '{attribute}' in line '{line}'")
                attributes[match.group(1)] = match.group(2)

            if attributes["type"] == "dir":
                prefix = name + "/"
                archive_name = name if name != "" else "."

                if archive_name in self.ignored_archives:
                    archive = None
                    continue

                info = os.stat(os.path.join(self.directory, name))

                # NOTE(review): a bare "." entry is skipped above, so the "0"
                # branch looks nearly unreachable; presumably archive_name was
                # meant here. Confirm before changing.
                archive = Archive(self.got_db.get_archive_id(archive_name), archive_name,
                                  int(info.st_mtime) if name != "." else "0",
                                  info.st_size if stat.S_ISREG(info.st_mode) else 0)
                db.add_archive(archive)
                got_archive = self.got_db.get_archive_by_id(archive.id)
                if got_archive is not None:
                    # Reuse the file indices of the actual db for known files.
                    archive.file_index.add_files(got_archive.files)

            elif attributes["type"] == "file":
                if archive is None:
                    # File outside any archive, or inside an ignored one.
                    continue

                if name.startswith(prefix):
                    name = name[len(prefix):]

                hashes = self.hashes(archive_name, name)

                file = File(name)

                if "size" in attributes:
                    file.size = attributes["size"]
                if "md5" in attributes and "md5" in hashes:
                    file.md5 = "<" + attributes["md5"] + ">"
                if "sha1" in attributes and "sha1" in hashes:
                    file.sha1 = "<" + attributes["sha1"] + ">"
                if "sha256" in attributes and "sha256" in hashes:
                    file.sha256 = "<" + attributes["sha256"] + ">"
                if "status" in attributes:
                    if attributes["status"] in status_names:
                        file.status = status_names[attributes["status"]]
                    else:
                        print(f"unknown status {attributes['status']}")
                file.mtime = attributes["time"]
                if "crc" in attributes and "crc" in hashes:
                    file.crc = int(attributes["crc"], 16)
                status_override = self.get_status_override(archive_name, name)
                if status_override is not None:
                    file.status = status_override
                archive.add_file(file)

                detector_hashes = self.get_detector_hashes(archive_name, name)
                for detector_hash in detector_hashes:
                    if detector_hash not in detector_ids:
                        detector_id = self.got_db.get_detector_id(detector_hash)
                        detector_ids[detector_hash] = detector_id
                        # Separate names so the file's `name` is not overwritten.
                        (detector_name, detector_version) = detector_hash.split("|")
                        db.add_detector(Detector(detector_id, detector_name, detector_version))
                    else:
                        detector_id = detector_ids[detector_hash]
                    detector_file = File("", file.index, size="<???>", crc="<???>", md5="<???>", sha1="<???>", sha256="<???>")
                    detector_file.detector_id = detector_id
                    detector_file.status = 0
                    if got_archive is not None:
                        # Detector hashes are not computed here; copy them
                        # from the matching row of the actual db if present.
                        got_file = got_archive.get_detector_hash(file.index, detector_id)
                        if got_file is not None:
                            detector_file.mtime = got_file.mtime
                            detector_file.status = got_file.status
                            detector_file.size = got_file.size
                            detector_file.crc = got_file.crc
                            detector_file.md5 = got_file.md5
                            detector_file.sha1 = got_file.sha1
                            detector_file.sha256 = got_file.sha256
                    archive.add_detector_hash(detector_file)

        self.expected_db = db

    def hashes(self, archive_name, file_name):
        """Return the hash types expected for a file.

        An exception registered for "archive/file" wins over one registered
        for "archive/*"; without either, all hash types are expected.
        """
        exact_key = f"{archive_name}/{file_name}"
        if exact_key in self.hash_exceptions:
            return self.hash_exceptions[exact_key]
        wildcard_key = f"{archive_name}/*"
        if wildcard_key in self.hash_exceptions:
            return self.hash_exceptions[wildcard_key]
        return {"crc", "md5", "sha1", "sha256"}

    def get_status_override(self, archive_name, file_name):
        """Return the status override for a file, or None if there is none.

        An override registered for "archive/file" wins over one registered
        for "archive/*".
        """
        exact_key = f"{archive_name}/{file_name}"
        if exact_key in self.status_overrides:
            return self.status_overrides[exact_key]
        wildcard_key = f"{archive_name}/*"
        if wildcard_key in self.status_overrides:
            return self.status_overrides[wildcard_key]
        return None

    def get_detector_hashes(self, archive_name, file_name):
        """Return the detector keys registered for a file, or an empty list.

        A list registered for "archive/file" wins over one registered for
        "archive/*".
        """
        exact_key = f"{archive_name}/{file_name}"
        if exact_key in self.detector_hashes:
            return self.detector_hashes[exact_key]
        wildcard_key = f"{archive_name}/*"
        if wildcard_key in self.detector_hashes:
            return self.detector_hashes[wildcard_key]
        return []


def destrsvis(str):
    # Placeholder: currently returns its argument unchanged.
    # NOTE(review): presumably meant to undo strsvis-style escaping of file
    # names in mkmamedb output — confirm before implementing.
    # TODO
    return str


def main():
    """Parse the command line and run the comparison.

    Constructing CkmameDBComparator performs the comparison and exits the
    process itself; 0 is returned only if it ever returns.
    """
    cli = argparse.ArgumentParser(prog='compare-ckmamedb', description="compare .ckmamedb files")
    cli.add_argument('ckmamedb', help='.ckmamedb file')
    cli.add_argument(
        "-u",
        "--unzipped",
        action='store_true',
        help="ROMs are unzipped in file system",
    )
    options = cli.parse_args()

    CkmameDBComparator(options.ckmamedb, options.unzipped)
    return 0


if __name__ == "__main__":
    sys.exit(main())
