# Copyright (c) DataLab Platform Developers, BSD 3-Clause license, see LICENSE file.

"""
Module providing test utilities
"""

from __future__ import annotations

import atexit
import functools
import os
import os.path as osp
import pathlib
import subprocess
import sys
import tempfile
from collections.abc import Callable
from typing import Any

import numpy as np
from guidata.configtools import get_module_data_path

from cdl.config import MOD_NAME
from cdl.env import execenv

# Registered test data folders (absolute paths, see add_test_path). The first
# entry is the default location used by get_test_fnames (with ``in_folder``)
# and by get_output_data_path.
TST_PATH: list[str] = []


def add_test_path(path: str) -> None:
    """Register a test data path, normalized and made absolute.

    Nothing happens if the (normalized) path is already registered.

    Args:
        path: path to add to the list of test data paths

    Raises:
        FileNotFoundError: if the path is not registered yet and does not exist
    """
    abs_path = osp.abspath(osp.normpath(path))
    if abs_path in TST_PATH:
        # Already registered: keep the list free of duplicates
        return
    if not osp.exists(abs_path):
        raise FileNotFoundError(f"Test data path does not exist: {abs_path}")
    TST_PATH.append(abs_path)


def add_test_path_from_env(envvar: str) -> None:
    """Register the test data path stored in an environment variable, if any.

    Nothing happens when the variable is unset or empty. Otherwise its value
    is passed to add_test_path, which raises FileNotFoundError if that path
    does not exist.

    Args:
        envvar: name of the environment variable holding the path
    """
    # Note: this function is used in third-party plugins
    env_path = os.environ.get(envvar)
    if not env_path:
        return
    add_test_path(env_path)


# Register the test data files and folders pointed to by the `CDL_DATA`
# environment variable (no-op when the variable is unset or empty):
add_test_path_from_env(envvar="CDL_DATA")


def add_test_module_path(modname: str, relpath: str) -> None:
    """Register a test data path located relative to a module.

    Used for module-local data residing in the module directory, which may be
    shipped under ``sys.prefix`` / share / ... once installed.

    Args:
        modname: name of an already imported module, as found in ``sys.modules``
        relpath: path of the data folder, relative to the module

    Raises:
        FileNotFoundError: if the resolved path is not registered yet and does
         not exist (raised by add_test_path)
    """
    data_path = get_module_data_path(modname, relpath=relpath)
    add_test_path(data_path)


# Register the test data files and folders shipped with the DataLab module
# (its ``data/tests`` data folder):
add_test_module_path(MOD_NAME, relpath=osp.join("data", "tests"))


def get_test_fnames(pattern: str, in_folder: str | None = None) -> list[str]:
    """
    Return the absolute path list to test files with specified pattern

    Pattern may be a file name (basename), a wildcard (e.g. *.txt)...
    The search is recursive; results are sorted per searched folder.

    Args:
        pattern: pattern to match
        in_folder: sub-folder of the first test data path to search in
         (default: None, search in all test data paths)

    Returns:
        List of matching paths, as strings

    Raises:
        FileNotFoundError: if nothing matches the pattern
    """
    if in_folder:
        search_dirs = [osp.join(TST_PATH[0], in_folder)]
    else:
        search_dirs = TST_PATH
    matches = [
        str(match)
        for folder in search_dirs
        for match in sorted(pathlib.Path(folder).rglob(pattern))
    ]
    if not matches:
        raise FileNotFoundError(f"Test file(s) {pattern} not found")
    return matches


def try_open_test_data(title: str, pattern: str) -> Callable:
    """Build a decorator running a test function on every matching data file.

    The decorated function is replaced by a wrapper taking no argument. The
    wrapper prints ``title`` as a header, then calls the original function
    once per test data file matching ``pattern`` (see get_test_fnames),
    passing the file name as the only argument. If no file matches, a message
    is printed instead and the function is not called.

    Args:
        title: title printed before opening the files
        pattern: file name pattern (basename or wildcard) of the test files

    Returns:
        Decorator to apply to a function taking a file name as argument
    """

    def try_open_test_data_decorator(func: Callable) -> Callable:
        """Wrap ``func`` so that it runs on every matching test data file"""

        @functools.wraps(func)
        def func_wrapper() -> None:
            """Run the wrapped function on every matching test data file"""
            execenv.print(title + ":")
            execenv.print("-" * len(title))
            try:
                # Only the file lookup is guarded: a FileNotFoundError raised
                # by ``func`` itself is a real test failure and must propagate
                # instead of being reported as missing test data
                fnames = get_test_fnames(pattern)
            except FileNotFoundError:
                execenv.print(f"  No test data available for {pattern}")
            else:
                for fname in fnames:
                    execenv.print(f"=> Opening: {fname}")
                    func(fname)
            finally:
                execenv.print(os.linesep)

        return func_wrapper

    return try_open_test_data_decorator


def get_default_test_name(suffix: str | None = None) -> str:
    """Return default test name based on script name"""
    name = osp.splitext(osp.basename(sys.argv[0]))[0]
    if suffix is not None:
        name += "_" + suffix
    return name


def get_output_data_path(extension: str, suffix: str | None = None) -> str:
    """Return the full path of a data file generated by the running test script.

    The file is located in the first registered test data path, and named
    ``<test name>.<extension>`` (test name as given by get_default_test_name).

    Args:
        extension: file extension, without the leading dot
        suffix: optional suffix appended to the test name

    Returns:
        Full path of the output data file
    """
    file_name = f"{get_default_test_name(suffix)}.{extension}"
    return osp.join(TST_PATH[0], file_name)


class CDLTemporaryDirectory(tempfile.TemporaryDirectory):
    """DataLab's temporary directory class that ignores errors when cleaning up,
    and restores the current working directory after cleanup"""

    def __init__(self) -> None:
        super().__init__()
        self.__cwd = os.getcwd()

    def cleanup(self) -> None:
        """Cleanup temporary directory and ignore errors"""
        os.chdir(self.__cwd)
        try:
            super().cleanup()
        except (PermissionError, RecursionError):
            pass


def get_temporary_directory() -> str:
    """Create a temporary directory that is removed when the interpreter exits.

    Returns:
        Path of the newly created temporary directory
    """
    tmp_dir = CDLTemporaryDirectory()
    # The registered bound method also keeps ``tmp_dir`` referenced until exit:
    # a TemporaryDirectory object removes its directory when it is destroyed,
    # so without this reference the directory could vanish before exit
    atexit.register(tmp_dir.cleanup)
    return tmp_dir.name


def exec_script(
    path: str,
    wait: bool = True,
    args: list[str] | None = None,
    env: dict[str, str] | None = None,
) -> None:
    """Run test script with the current Python interpreter.

    Args:
        path: path to script
        wait: if True, wait for script to finish
        args: arguments to pass to script (one list item per argument)
        env: environment variables to pass to script (None: inherit the
         current environment)
    """
    # Pass an argument list without a shell: ``subprocess`` quotes each item as
    # needed, so an interpreter or script path containing spaces works (a
    # shell-joined string with an unquoted ``sys.executable`` did not)
    command = [sys.executable, path, *([] if args is None else args)]
    # In unattended mode, the script's error output is discarded
    stderr = subprocess.DEVNULL if execenv.unattended else None
    # pylint: disable=consider-using-with
    proc = subprocess.Popen(command, stderr=stderr, env=env)
    if wait:
        proc.wait()


def get_script_output(
    path: str, args: list[str] = None, env: dict[str, str] | None = None
) -> str:
    """Run test script and return its output.

    Args:
        path (str): path to script
        args (list): arguments to pass to script
        env (dict): environment variables to pass to script

    Returns:
        str: script output
    """
    command = [sys.executable, '"' + path + '"'] + ([] if args is None else args)
    result = subprocess.run(
        " ".join(command), capture_output=True, text=True, env=env, check=False
    )
    return result.stdout.strip()


def _compare_values(val1: Any, val2: Any, level: int, diff_label: str) -> bool:
    """Compare two values found at the same position of two containers

    Lists/tuples and dictionaries are compared recursively, one level deeper.
    Any other values are compared through their string representation. The
    caller is expected to have just printed a "Checking ..." prefix without a
    line break; this function completes that line.

    Args:
        val1: first value
        val2: second value
        level: recursion level of the caller
        diff_label: label printed before ``: <val1> != <val2>`` when the
         values differ

    Returns:
        True if values are the same, False otherwise
    """
    if isinstance(val1, (list, tuple)) and isinstance(val2, (list, tuple)):
        execenv.print("")
        return compare_lists(val1, val2, level + 1)
    if isinstance(val1, dict) and isinstance(val2, dict):
        execenv.print("")
        return compare_metadata(val1, val2, level + 1)
    # Values of different kinds (e.g. a list vs. a scalar) are compared here
    # too, so they are reported as different instead of being recursed into
    same_value = str(val1) == str(val2)
    if not same_value:
        execenv.print(f"{diff_label}: {val1} != {val2}")
    # Verdict for this value only (not for the whole container)
    execenv.print("OK" if same_value else "KO")
    return same_value


def compare_lists(list1: list, list2: list, level: int = 1) -> bool:
    """Compare two lists, element by element (recursively)

    Args:
        list1: first list
        list2: second list
        level: recursion level (used for output indentation)

    Returns:
        True if lists have the same length and the same elements,
        False otherwise
    """
    prefix = "  " * level
    same = len(list1) == len(list2)
    if not same:
        execenv.print(f"{prefix}Different lengths: {len(list1)} != {len(list2)}")
    for idx, (elem1, elem2) in enumerate(zip(list1, list2)):
        execenv.print(f"{prefix}Checking element {idx}...", end=" ")
        # Compare first, so that every element is checked and reported even
        # after a first difference has been found
        same = _compare_values(elem1, elem2, level, "Different values") and same
    return same


def compare_metadata(
    dict1: dict[str, Any], dict2: dict[str, Any], level: int = 1
) -> bool:
    """Compare metadata dictionaries without private elements

    Private elements are those whose key starts with "__"; they are ignored
    on both sides. Every remaining key of ``dict1`` must exist in ``dict2``
    with the same value; extra keys of ``dict2`` are ignored.

    Args:
        dict1: first dictionary, exclusively with string keys
        dict2: second dictionary, exclusively with string keys
        level: recursion level (used for output indentation)

    Returns:
        True if metadata is the same, False otherwise
    """
    dict_a = {key: val for key, val in dict1.items() if not key.startswith("__")}
    dict_b = {key: val for key, val in dict2.items() if not key.startswith("__")}
    same = True
    prefix = "  " * level
    for key, val_a in dict_a.items():
        if key not in dict_b:
            execenv.print(f"{prefix}Missing key {key} in second dictionary")
            same = False
            continue
        execenv.print(f"{prefix}Checking key {key}...", end=" ")
        diff_label = f"Different values for key {key}"
        # Compare first, so that every key is checked and reported even after
        # a first difference has been found
        same = _compare_values(val_a, dict_b[key], level, diff_label) and same
    return same


def __array_to_str(data: np.ndarray) -> str:
    """Return a compact description of the array properties

    The description reads ``<shape>,<dtype>,<min>→<max>,µ=<mean>`` (shape
    dimensions joined with "×", statistics with 2 significant digits).
    Zero-size arrays are described as ``<shape>,<dtype>,empty``.

    Args:
        data: array, or array-like (converted with ``numpy.asanyarray``, which
         keeps ndarray subclasses such as masked arrays unchanged)

    Returns:
        Description string
    """
    data = np.asanyarray(data)
    dims = "×".join(str(dim) for dim in data.shape)
    if data.size == 0:
        # min()/max() raise ValueError on zero-size arrays
        return f"{dims},{data.dtype},empty"
    return f"{dims},{data.dtype},{data.min():.2g}→{data.max():.2g},µ={data.mean():.2g}"


def check_array_result(
    title: str,
    res: np.ndarray,
    exp: np.ndarray,
    rtol: float = 1.0e-5,
    atol: float = 1.0e-8,
) -> None:
    """Assert that two arrays are almost equal.

    A one-line summary of both arrays is printed first, then the arrays are
    compared with ``numpy.allclose`` using the given tolerances.

    Args:
        title: title of the check, used in the printed summary
        res: computed array
        exp: expected array
        rtol: relative tolerance
        atol: absolute tolerance

    Raises:
        AssertionError: if the arrays are not almost equal
    """
    res_summary = __array_to_str(res)
    exp_summary = __array_to_str(exp)
    message = f"{title}: {res_summary} (expected: {exp_summary})"
    execenv.print(message)
    assert np.allclose(res, exp, rtol=rtol, atol=atol), message


def check_scalar_result(
    title: str,
    res: float,
    exp: float | tuple[float, ...],
    rtol: float = 1.0e-5,
    atol: float = 1.0e-8,
) -> None:
    """Assert that two scalars are almost equal.

    Args:
        title: title of the check, used in the printed message
        res: computed value
        exp: expected value, or tuple of acceptable expected values (the check
         passes if ``res`` is close to any of them)
        rtol: relative tolerance
        atol: absolute tolerance

    Raises:
        AssertionError: if ``res`` is not close to the expected value(s)
    """
    message = f"{title}: {res} (expected: {exp})"
    execenv.print(message)
    # A single expected value is handled as a one-item tuple of candidates
    candidates = exp if isinstance(exp, tuple) else (exp,)
    assert any(
        np.isclose(res, candidate, rtol=rtol, atol=atol) for candidate in candidates
    ), message
