#!/usr/bin/env python3

# Test every patch from files folder and output error on failure
#
# Copyright (C) 2016 Intel Corporation
#
# SPDX-License-Identifier: GPL-2.0-only

import os
import subprocess
import sys

# Directory containing this script; the selftest patches live in its 'files' subfolder.
currentdir = os.path.dirname(os.path.abspath(__file__))
patchesdir = os.path.join(currentdir, 'files')

# Parent and grandparent of the script directory. currentdir is already
# absolute and normalized, so stepping up through os.pardir is purely lexical.
topdir = os.path.normpath(os.path.join(currentdir, os.pardir))
parentdir = os.path.normpath(os.path.join(topdir, os.pardir))

# path to the repo root: two levels above parentdir
repodir = os.path.normpath(os.path.join(parentdir, os.pardir, os.pardir))

def print_results(counts):
    """
    Print the testsuite summary.

    The output is a banner naming the suite (the basename of topdir), the
    total of all counters, and then one line per result category.

    counts must provide the keys "xpass", "xfail", "xskip", "pass", "fail",
    "skip" and "error".
    """
    total = sum(counts.values())
    rule = "============================================================================"

    print(rule)
    print("Testsuite summary for %s" % os.path.basename(topdir))
    print(rule)
    print("# TOTAL: " + str(total))
    # Categories are listed in the order they appear in the report.
    for category in ("xpass", "xfail", "xskip", "pass", "fail", "skip", "error"):
        print("# " + category.upper() + ": " + str(counts[category]))
    print(rule)

def get_patches(patchesdir):
    """
    Return a list of dict mapping test IDs to patch filenames and expected results.
    """
    patch_list = []
    for root, dirs, patches in os.walk(patchesdir):
        for patch in patches:
            part = patch.split('.')
            klass, testname, expected_result = part[0], part[1], part[-1]
            testid = f".{klass}.{testname}"
            patch_list.append({
                "testid": testid,
                "patch": patch,
                "expected": expected_result,
                "root" : root,
            })
    return patch_list

def analyze_result(results, patch, counts, return_code):
    """
    Classify the outcome of one patchtest run and update the counters.

    results is the text output of the run, patch is one entry produced by
    get_patches(), counts is the dict of result counters and return_code is
    the exit status of the run.

    Only the first output line containing the patch's test ID is considered;
    the text before its first ':' is taken as the result. When the result
    matches the expected one and the exit status agrees (non-zero for an
    expected FAIL, zero for an expected PASS or SKIP) the matching
    "xfail"/"xpass"/"xskip" counter is incremented. Otherwise the plain
    "pass"/"fail"/"skip" counter is incremented, or "error" when the result
    is none of those. A matching line without any ':' is reported as a bad
    result rather than raising ValueError.

    If no line mentions the test ID a message is printed and no counter
    changes. Returns counts (the same dict, updated in place).
    """
    testid = patch["testid"]
    expected = str(patch["expected"]).upper()
    outcomes = ("PASS", "FAIL", "SKIP")

    for resultline in results.splitlines():
        if testid not in resultline:
            continue

        shown_id = testid.strip(".")
        patch_name = os.path.basename(patch["patch"])
        # partition() never fails: without a ':' the whole line becomes the
        # result and is handled by the "Bad result" branch below.
        result = resultline.partition(':')[0].upper()

        # An expected failure must exit non-zero; expected pass/skip must exit zero.
        exit_matches = (return_code != 0) if expected == "FAIL" else (return_code == 0)

        if result in outcomes and result == expected and exit_matches:
            counts["x" + result.lower()] += 1
            print("X%s: %s (file: %s)" % (result, shown_id, patch_name))
        else:
            print("%s: %s (%s)" % (result, shown_id, patch_name))
            if result in outcomes:
                counts[result.lower()] += 1
            else:
                print("Bad result on test %s against %s" % (shown_id, patch_name))
                counts["error"] += 1
        break
    else:
        print ("No test for=%s" % patch["patch"])

    return counts

def run_sh(cmd):
    """
    Run cmd through the shell from the script directory (currentdir).

    stderr is merged into stdout; the combined text is returned with
    surrounding whitespace stripped. A non-zero exit status raises
    subprocess.CalledProcessError.
    """
    completed = subprocess.run(
        cmd,
        shell=True,
        cwd=currentdir,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        check=True,
    )
    return completed.stdout.strip()

def get_git_state():
    """
    Snapshot where HEAD currently points.

    Returns a dict with "branch" (output of `git rev-parse --abbrev-ref HEAD`)
    and "commit" (output of `git rev-parse HEAD`), or None after printing a
    message when the work-tree probe command fails.
    """
    try:
        # Only the exit status matters here; the printed value is not inspected.
        run_sh("git rev-parse --is-inside-work-tree")
    except subprocess.CalledProcessError:
        print("Not a Git repository")
        return None

    branch = run_sh("git rev-parse --abbrev-ref HEAD")
    commit = run_sh("git rev-parse HEAD")
    return {"branch": branch, "commit": commit}

def restore_git_state(git_state):
    """
    Switch the repository back to the state recorded in git_state.

    git_state is a dict as returned by get_git_state(). A recorded branch of
    "HEAD" means HEAD was detached, so the recorded commit is checked out
    detached; otherwise the recorded branch is switched to.
    """
    branch = git_state['branch']
    assert branch is not None, "Failed to restore git state, no valid branch"

    if branch == "HEAD":
        switch_cmd = f"git switch --detach {git_state['commit']}"
    else:
        switch_cmd = f"git switch {branch}"
    run_sh(switch_cmd)

def is_git_state_same(before, after):
    """
    Compare two git state dicts on their "branch" and "commit" entries.

    Every differing entry is reported on stdout. Returns True when both
    entries match and False otherwise.
    """
    differences = 0

    for field in ("branch", "commit"):
        old, new = before[field], after[field]
        if old != new:
            print(f"Git state changed: {field} changed: {old} -> {new}")
            differences += 1

    return differences == 0

def git_attach_head(temp_branch):
    """Switch to temp_branch using `git switch -C`, so HEAD is attached to a branch."""
    switch_cmd = f"git switch -C {temp_branch}"
    run_sh(switch_cmd)

def git_detach_head():
    """
    Detach HEAD at the current commit and return the resulting git state.

    After switching, the abbreviated ref name is checked to be "HEAD" to
    confirm the detach took effect.
    """
    run_sh("git switch --detach HEAD")
    current_ref = run_sh("git rev-parse --abbrev-ref HEAD")
    assert current_ref == "HEAD", "Failed to enter detached HEAD state"

    return get_git_state()

# Once the tests are in oe-core, we can remove the testdir param and use os.path.dirname to get relative paths
def test(root, patch):
    """
    Run patchtest against a single patch file.

    root is the directory containing the patch and patch is its filename.
    patchtest is invoked with HEAD as base commit, repodir as repository and
    the 'tests' directory under topdir as test directory.

    Returns a (returncode, stdout) tuple. stderr is captured but not
    returned. Raises FileNotFoundError if the patchtest executable cannot be
    found on PATH.
    """
    patchpath = os.path.abspath(os.path.join(root, patch))

    # Pass an argument vector instead of a shell string so that paths
    # containing spaces or shell metacharacters reach patchtest unchanged.
    cmd = [
        'patchtest',
        '--base-commit', 'HEAD',
        '--repodir', repodir,
        '--testdir', '%s/tests' % topdir,
        '--patch', patchpath,
    ]
    results = subprocess.run(cmd, capture_output=True, universal_newlines=True)

    return results.returncode, results.stdout

def test_head_attached(patches, counts, branch):
    """
    Run every patch with HEAD attached to the given branch.

    The git state is recorded after switching to the branch and again after
    all patches have run; the two snapshots must be identical. Returns the
    updated counts.
    """
    git_attach_head(branch)
    state_at_start = get_git_state()

    for info in patches:
        exit_code, output = test(info["root"], info["patch"])
        counts = analyze_result(output, info, counts, exit_code)

    state_at_end = get_git_state()
    unchanged = is_git_state_same(state_at_start, state_at_end)
    assert unchanged, "Repository state changed after attached HEAD test."
    return counts

def test_head_detached(patches, counts):
    """
    Run the first patch with the repository in detached HEAD state.

    The patch result itself is tallied through analyze_result(). In addition,
    the git state recorded right after detaching is compared with the state
    after the run: if it is unchanged an extra "xpass" is counted for the
    detached-HEAD check, otherwise an "error" is counted.

    patches must contain at least one entry. Returns the updated counts.
    """
    detached_before = git_detach_head()
    patch_info = patches[0]
    return_code, results = test(patch_info["root"], patch_info["patch"])
    detached_after = get_git_state()
    counts = analyze_result(results, patch_info, counts, return_code)

    if is_git_state_same(detached_before, detached_after):
        counts["xpass"] = counts["xpass"] + 1
        print("XPASS: %s.test_head_detached" % os.path.basename(__file__))
    else:
        print(" Test '%s' failed with git in detach HEAD mode: state changed after test" % patch_info["testid"].strip("."))
        counts["error"] = counts["error"] + 1

    return counts

def run_tests(patches, counts):
    """
    Run the attached-HEAD and detached-HEAD test passes and return the counts.

    The current git state is recorded first. Whatever happens during the
    passes (including a failed assertion or a git error), the recorded state
    is restored and the temporary branch is deleted afterwards, so a failing
    run does not leave the checkout on the temporary branch or on a detached
    HEAD.
    """
    temp_branch = "test_patchtest_head_attached"
    git_state = get_git_state()
    # get_git_state() returns None outside a repository; fail with a clear message.
    assert git_state is not None, "Cannot run patchtest selftest outside of a Git repository"
    assert git_state['branch'] != temp_branch, f"Cannot run patchtest selftest while on branch '{temp_branch}'"

    try:
        counts = test_head_attached(patches, counts, temp_branch)
        counts = test_head_detached(patches, counts)
    finally:
        # Leave the temporary branch before deleting it.
        restore_git_state(git_state)
        run_sh(f"git branch -D {temp_branch}")

    return counts

if __name__ == '__main__':
    # One counter per result category, all starting at zero.
    counts = dict.fromkeys(
        ("pass", "fail", "skip", "xpass", "xfail", "xskip", "error"), 0)

    # NOTE(review): never read in the visible script; candidate for removal.
    results = None

    patches = get_patches(patchesdir)
    if patches:
        counts = run_tests(patches, counts)
        print_results(counts)
    else:
        # Nothing to test is treated as a failure of the selftest itself.
        print(f"Error: Unable to find patch(es) in {patchesdir}")
        sys.exit(1)
