#! /usr/bin/env python3

"""
IN330 RISC-V assembler/emulator project -- test system.
Runs the assembler/emulator on all test files and reports errors in the output.

Environment variables (all optional):
  BASE_DIR          Folder in which tests/**/*.s is searched [.]
  RISCV_ASSEMBLER   Path to assembler program [$BASE_DIR/riscv-assembler]
  RISCV_EMULATOR    Path to emulator program [$BASE_DIR/riscv-emulator]
  TEST_FILES        Glob of files to test [$BASE_DIR/tests/**/*.s]

For the assembler, reference output is obtained by running a compiler followed
by objcopy; the PATH is searched for the following executables, which can also
be specified with environment variables:
  1. clang, clang-14, clang-15, $TOOL_CLANG
  2. riscv64-elf-gcc, riscv64-unknown-elf-gcc, $TOOL_GCC
  3. llvm-objcopy, llvm-objcopy-14, llvm-objcopy-15, riscv64-elf-objcopy,
     riscv64-unknown-elf-objcopy, $TOOL_OBJCOPY
"""

import subprocess
import shutil
import struct
import glob
import sys
import os
import re

### Test file selection ###

# Root folder of the project; defaults to the current directory.
BASE_DIR = os.environ.get("BASE_DIR", ".")

# Programs under test. Each one can be overridden with the environment
# variable of the same name; otherwise it is looked up in BASE_DIR.
RISCV_ASSEMBLER = os.environ.get(
    "RISCV_ASSEMBLER", os.path.join(BASE_DIR, "riscv-assembler"))
RISCV_EMULATOR = os.environ.get(
    "RISCV_EMULATOR", os.path.join(BASE_DIR, "riscv-emulator"))

# Test files, taken from the TEST_FILES glob when it is set. glob returns its
# matches in arbitrary order, so sort them to get the same report every run.
ALL_FILES = sorted(glob.glob(
    os.environ.get("TEST_FILES", os.path.join(BASE_DIR, "tests/**/*.s")),
    recursive=True))

# Remove ./ from paths for cleaner outputs
ALL_FILES = [f[2:] if f.startswith("./") else f for f in ALL_FILES]

### General utilities ###

class TestFailure(Exception):
    """Raised when a test fails; the message describes what went wrong."""


class TestSkipped(Exception):
    """Raised when a test cannot be evaluated; the message gives the reason."""

# Terminal colors. Each helper wraps a string between an ANSI escape sequence
# selecting the color and the sequence resetting all attributes.
def _colorizer(sgr):
    start = "\x1b[" + sgr + "m"
    return lambda s: start + s + "\x1b[0m"

GREEN  = _colorizer("32;1")
RED    = _colorizer("31;1")
YELLOW = _colorizer("33;1")
PURPLE = _colorizer("35")
CYAN   = _colorizer("36")

def print_with_header(string, header):
    """Print `string` one line at a time, each prefixed by `header` and a space.

    Newlines at the very start and end of `string` are dropped beforehand.
    """
    trimmed = string.strip("\n")
    for text in trimmed.split("\n"):
        print(header + " " + text)

### Resolution of external dependencies ###

# Find the first program in PATH from the provided list
def find_program(env_var, options):
    """Return the program named by `env_var`, or the first usable option.

    When the environment variable `env_var` is set, its value is returned as
    is (it is not checked). Otherwise the first entry of `options` that
    shutil.which() can locate is returned, or None if there is none.
    """
    override = os.environ.get(env_var)
    if override is not None:
        return override

    for candidate in options:
        if shutil.which(candidate) is not None:
            return candidate
    return None

# Some people have no global name for LLVM toolchains. Feel free to add to the
# end of the lists to match your setup.
CLANG = find_program(
    "TOOL_CLANG",
    ["clang", "clang-14", "clang-15"],
)
OBJCOPY = find_program(
    "TOOL_OBJCOPY",
    [
        "llvm-objcopy",
        "llvm-objcopy-14",
        "llvm-objcopy-15",
        "riscv64-elf-objcopy",
        "riscv64-unknown-elf-objcopy",
    ],
)
RV64_GCC = find_program(
    "TOOL_GCC",
    ["riscv64-elf-gcc", "riscv64-unknown-elf-gcc"],
)

# One compiler (either of the two) and objcopy are required; stop right away
# when they are missing. Note that the exit status is 0 even in that case.
if all(tool is None for tool in (CLANG, RV64_GCC)):
    print(RED("error:"), "no RISC-V compiler to test the assembler!")
    sys.exit(0)
if OBJCOPY is None:
    print(RED("error:"), "no (llvm-)objcopy to test the assembler!")
    sys.exit(0)

### Test utilities ###

# Split into non-empty lines, remove spaces, translate to lower case
def normalize_and_split_lines(s):
    """Return the lines of `s` in lower case, without spaces or tabs.

    Lines that end up empty are left out of the result.
    """
    kept = []
    for raw in s.split("\n"):
        compact = re.sub(r"[ \t\n]+", "", raw.lower())
        if compact:
            kept.append(compact)
    return kept

# Alternative names of the integer registers, ordered by register number: the
# name found at index N stands for xN.
riscv_regs = (
    "zero ra sp  gp  tp t0 t1 t2 "     # x0  .. x7
    "s0   s1 a0  a1  a2 a3 a4 a5 "     # x8  .. x15
    "a6   a7 s2  s3  s4 s5 s6 s7 "     # x16 .. x23
    "s8   s9 s10 s11 t3 t4 t5 t6"      # x24 .. x31
).split()

# Parse string into a final state dictionary
def parse_final_state(s):
    """Parse a register dump into a {register number: value} dictionary.

    Every non-empty line of `s` must read "<register>: <value>". The register
    is either x0..x31 or one of the names in riscv_regs; the value is an
    integer literal as accepted by int(value, 0), so decimal or with a
    0x/0o/0b prefix. Case, spaces and tabs are ignored. When a register
    appears several times, the last value wins.

    Raises TestFailure on an unknown register and on a line that has no ":"
    or whose value is not an integer. Malformed lines must not escape as
    ValueError, since only TestFailure is turned into a reported failure.
    """
    state = dict()
    for line in normalize_and_split_lines(s):
        reg, colon, value = line.partition(":")
        if not colon:
            raise TestFailure(
                "Malformed line '{}' in output (expected 'register: value')"
                .format(line))
        reg = reg.strip()

        if re.fullmatch(r"x\d+", reg) is not None:
            regnum = int(reg[1:])
            if regnum > 31:
                raise TestFailure("Unknown register '{}' in output".format(reg))
        elif reg in riscv_regs:
            regnum = riscv_regs.index(reg)
        else:
            raise TestFailure("Unknown register '{}' in output".format(reg))

        # strip() still matters here: normalization removes spaces and tabs
        # but leaves other whitespace such as "\r" in place.
        try:
            state[regnum] = int(value.strip(), 0)
        except ValueError:
            raise TestFailure(
                "Invalid value '{}' for register '{}' in output"
                .format(value.strip(), reg)) from None
    return state

def assert_equal_hex(ref, out):
    """Check that the instruction words in `out` are the same as in `ref`.

    Raises TestFailure if the two sequences differ in length, or with one
    line per differing instruction (numbered from 1) otherwise.
    """
    if len(ref) != len(out):
        raise TestFailure("Output should contain {} instructions, but has {}"
                          .format(len(ref), len(out)))

    mismatches = [
        "Instruction #{} should be {:08x}, but is {:08x}\n"
        .format(number, expected, actual)
        for number, (expected, actual) in enumerate(zip(ref, out), start=1)
        if actual != expected
    ]

    if mismatches:
        raise TestFailure("".join(mismatches))

def assert_equal_regs(ref, out):
    """Check that two register dumps describe the same final state.

    Both `ref` and `out` are texts understood by parse_final_state(). A
    register missing from a dump counts as 0. Raises TestFailure listing
    every register whose value differs.
    """
    expected = parse_final_state(ref)
    actual = parse_final_state(out)

    diffs = []
    for number, name in enumerate(riscv_regs):
        want = expected.get(number, 0)
        got = actual.get(number, 0)
        if want != got:
            diffs.append("\n- x{} ({}) should be {}, but it is {}".format(
                number, name, want, got))

    if diffs:
        raise TestFailure("".join(diffs))

def convert_gcc_syntax(input, output):
    """Write to `output` a copy of the assembly file `input` adapted for GCC.

    Two transformations are applied:
    - comments (from "#" to the end of the line) are removed, along with the
      lines that are empty afterwards;
    - for the j, jal, beq, bne, blt and bge instructions (in any case), the
      last operand X is rewritten as "(.+ X)".
      NOTE(review): presumably this makes GNU as read X as an offset from the
      current location "." -- confirm against the project's jump syntax.

    Both files are handled as UTF-8, like get_expected() does for the same
    test files, rather than with the locale's default encoding.
    """
    with open(input, "r", encoding="utf-8") as fp:
        asm = fp.read()

    # Remove comments, then the lines that have nothing left
    lines = [line.partition("#")[0] for line in asm.splitlines()]
    asm = "\n".join(line for line in lines if line) + "\n"

    RE_JUMP = re.compile(r'\b(j|jal|beq|bne|blt|bge)\b\s*([^\n]+)', re.I)
    def rep(m):
        # m[1] is the mnemonic, m[2] the operands up to the end of the line
        args = m[2].split(",")
        args[-1] = "(.+ " + args[-1] + ")"
        return m[1] + " " + ", ".join(args)

    asm = RE_JUMP.sub(rep, asm)
    with open(output, "w", encoding="utf-8") as fp:
        fp.write(asm)

def assembler_reference(file):
    """Assemble `file` with the reference toolchain and return its code.

    GCC is used when available, on a copy of the file converted by
    convert_gcc_syntax(); otherwise Clang assembles the file directly. The
    .text section of the object file is then extracted with objcopy.
    Intermediate files are written next to `file`: <prog>.o, <prog>.bin and,
    with GCC, <prog>.s.gnu, where <prog> is `file` without its extension.

    Returns a tuple of integers, one per 32-bit little-endian word.

    Raises TestFailure when a tool cannot be run or exits with an error, or
    when the extracted code is not a whole number of 32-bit words. These are
    reported as failures instead of aborting the whole test run.
    """
    prog = os.path.splitext(file)[0]
    prog_obj = prog + ".o"
    prog_bin = prog + ".bin"

    if RV64_GCC is not None:
        # Built from prog rather than file[:-2], which was only right for
        # two-character extensions such as ".s"
        file_gcc = prog + ".s.gnu"
        convert_gcc_syntax(file, file_gcc)
        compile_cmd = [
            RV64_GCC, "-march=rv64i", "-mabi=lp64", "-x", "assembler",
            "-c", file_gcc, "-o", prog_obj]
    else:
        compile_cmd = [
            CLANG, "--target=riscv64", "-march=rv64g", "-c", file, "-o",
            prog_obj]
    objcopy_cmd = [OBJCOPY, "-O", "binary", "-j", ".text", prog_obj, prog_bin]

    try:
        subprocess.run(compile_cmd, check=True)
        subprocess.run(objcopy_cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        raise TestFailure("reference toolchain failed: {}".format(e)) from e

    with open(prog_bin, "rb") as fp:
        data = fp.read()
    # Explicit check instead of an assert, which python -O would remove
    if len(data) % 4 != 0:
        raise TestFailure(
            "reference code is {} bytes long, not a multiple of 4"
            .format(len(data)))
    return struct.unpack("<{}I".format(len(data) // 4), data)

def get_expected(file):
    """Extract the expected final state written in the comments of `file`.

    Only lines starting with "#" (after optional whitespace) are looked at.
    The comment lines found between an "EXPECTED" comment and an "END"
    comment are collected, without their "#" and surrounding whitespace.

    Returns the collected lines, each one preceded by a newline, or None if
    nothing was collected.
    """
    collected = []
    recording = False

    with open(file, encoding="utf-8") as fp:
        for raw in fp:
            # Ignore non-comments
            if re.match(r"\s*#", raw) is None:
                continue
            # Cleanup comment start and whitespaces
            text = re.sub(r"\s*#\s*", "", raw)
            text = re.sub(r"\s*$", "", text)

            if text == "EXPECTED":
                recording = True
            elif text == "END":
                recording = False
            elif recording:
                collected.append("\n" + text)

    # Make it None if it's an empty string
    return "".join(collected) or None

### Test system ###

class Tester:
    """Run the assembler and emulator tests and count their outcomes."""

    def __init__(self):
        # Number of tests that ended in each state
        self.passed, self.skipped, self.failed = 0, 0, 0

    def run_test(self, tag, name, f):
        """Run the test function `f` and print its result on one line.

        `tag` and `name` identify the test in the output. The test passes if
        `f` returns; TestSkipped and TestFailure are reported along with
        their message. Any other exception is not handled here.
        """
        print(tag, name + "... ", end="")
        sys.stdout.flush()
        try:
            f()
            print(GREEN("OK"))
            self.passed += 1
        except TestSkipped as e:
            print(YELLOW("SKIPPED"))
            print_with_header(str(e), YELLOW("|"))
            self.skipped += 1
        except TestFailure as e:
            print(RED("FAILED"))
            print_with_header(str(e), RED("|"))
            self.failed += 1

    def print_summary(self):
        """Print the number of passed tests, then skipped/failed if any."""
        total = self.passed + self.skipped + self.failed
        print("\nTotal: {}/{}".format(self.passed, total),
              GREEN("passed"), end="")
        if self.skipped > 0:
            print(",", self.skipped, YELLOW("skipped"), end="")
        if self.failed > 0:
            print(",", self.failed, RED("failed"), end="")
        print("")

    def test_assembler(self, filename):
        """Assemble `filename` and compare the result with the reference.

        The assembler is called with the source file and the path of the
        .hex file to produce. Raises TestFailure on timeout, on a non-zero
        exit code, if the .hex file is missing, or if its contents differ
        from the reference.
        """
        prog = os.path.splitext(filename)[0]
        prog_hex = prog + ".hex"
        reference = assembler_reference(filename)

        # Make sure a file left over by a previous run is not read instead
        if os.path.exists(prog_hex):
            os.remove(prog_hex)

        # Run the command, ignoring I/Os (we only use output files)
        try:
            rc = subprocess.run(
                [RISCV_ASSEMBLER, filename, prog_hex],
                stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                timeout=3)
        except subprocess.TimeoutExpired:
            raise TestFailure("timeout expired (likely infinite loop)")

        if rc.returncode != 0:
            raise TestFailure("exit code {}".format(rc.returncode))

        # Report a missing output file as a failure, as test_emulator() does,
        # rather than letting open() abort the whole run
        if not os.path.exists(prog_hex):
            raise TestFailure(".hex file does not exist")

        produced = []
        with open(prog_hex, "r") as fp:
            for line in fp.read().splitlines():
                try:
                    produced.append(int(line, 16))
                except ValueError:
                    # Lines that are not hexadecimal numbers (empty lines for
                    # instance) are ignored on purpose
                    pass

        assert_equal_hex(reference, produced)

    def test_emulator(self, filename):
        """Run the emulator on `filename` and check the final registers.

        The program is assembled by the reference toolchain and written to
        <prog>.ref.hex; the emulator is called with that file and the path
        of the .state file to produce. Raises TestFailure on timeout, on a
        non-zero exit code, if the .state file is missing or empty, or if
        the registers differ from the EXPECTED block of the test file.
        Raises TestSkipped if the test file has no EXPECTED block; this is
        only decided once the emulator has run and exited with code 0.
        """
        prog, _ = os.path.splitext(filename)
        prog_s = prog + ".s"
        prog_hex = prog + ".ref.hex"
        prog_state = prog + ".state"

        expected = get_expected(filename)

        ref_hex = assembler_reference(prog_s)
        with open(prog_hex, "w") as fp:
            for h in ref_hex:
                fp.write("{:08x}\n".format(h))

        # Make sure a file left over by a previous run is not read instead
        if os.path.exists(prog_state):
            os.remove(prog_state)

        # Run the command, ignoring I/Os (we only use output files)
        try:
            rc = subprocess.run(
                [RISCV_EMULATOR, prog_hex, prog_state],
                stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                timeout=3)
        except subprocess.TimeoutExpired:
            raise TestFailure("timeout expired (likely infinite loop)")

        if rc.returncode != 0:
            raise TestFailure("exit code {}".format(rc.returncode))

        if expected is None:
            raise TestSkipped("no EXPECTED block in test file")
        if not os.path.exists(prog_state):
            raise TestFailure(".state file does not exist")
        with open(prog_state, "r") as fp:
            state = fp.read()
        if state == "":
            raise TestFailure(".state file is empty")
        assert_equal_regs(expected, state)

    def run_all_tests(self):
        """Run every test file through each program that exists on disk."""
        # f=f binds the current file to each lambda, so the test function
        # does not depend on the loop variable when it is called
        if os.path.exists(RISCV_ASSEMBLER):
            for f in ALL_FILES:
                self.run_test(PURPLE("<assembler>"), f,
                    lambda f=f: self.test_assembler(f))
        if os.path.exists(RISCV_EMULATOR):
            for f in ALL_FILES:
                self.run_test(CYAN("<emulator>"), f,
                    lambda f=f: self.test_emulator(f))

# Run everything as soon as the script is executed. On Ctrl-C the remaining
# tests are abandoned, but the summary of what ran so far is still printed.
# Any other exception propagates before the summary is printed.
t = Tester()
try:
    t.run_all_tests()
except KeyboardInterrupt:
    print("\n" + YELLOW("Tests interrupted!"))
t.print_summary()
