import itertools
import re
import time
import warnings

import joblib
import numpy as np
import pytest
from numpy.testing import assert_array_equal

from sklearn import config_context, get_config
from sklearn.compose import make_column_transformer
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.fixes import _IS_WASM
from sklearn.utils.parallel import Parallel, delayed


def get_working_memory():
    """Return the ``working_memory`` entry of the active scikit-learn config."""
    current_config = get_config()
    return current_config["working_memory"]


@pytest.mark.parametrize("n_jobs", [1, 2])
@pytest.mark.parametrize("backend", ["loky", "threading", "multiprocessing"])
def test_configuration_passes_through_to_joblib(n_jobs, backend):
    """Check that the global configuration is passed to joblib jobs.

    Every job reads ``working_memory`` from its own view of the config; all of
    them must see the value set by the caller's ``config_context``.
    """
    expected_working_memory = 123
    n_tasks = 2

    with config_context(working_memory=expected_working_memory):
        run_parallel = Parallel(n_jobs=n_jobs, backend=backend)
        results = run_parallel(
            delayed(get_working_memory)() for _ in range(n_tasks)
        )

    assert_array_equal(results, [expected_working_memory] * n_tasks)


def test_parallel_delayed_warnings():
    """Informative warnings should be raised when mixing sklearn and joblib API"""
    # Using `sklearn.utils.parallel.Parallel` together with `joblib.delayed`
    # must warn once per task: the scikit-learn config cannot be propagated to
    # the workers in that case.
    warn_msg = "`sklearn.utils.parallel.Parallel` needs to be used in conjunction"
    with pytest.warns(UserWarning, match=warn_msg) as records:
        Parallel()(joblib.delayed(time.sleep)(0) for _ in range(10))
    assert len(records) == 10

    # Symmetrically, using `sklearn.utils.parallel.delayed` together with
    # `joblib.Parallel` must also warn once per task.
    warn_msg = (
        "`sklearn.utils.parallel.delayed` should be used with "
        "`sklearn.utils.parallel.Parallel` to make it possible to propagate"
    )
    with pytest.warns(UserWarning, match=warn_msg) as records:
        joblib.Parallel()(delayed(time.sleep)(0) for _ in range(10))
    assert len(records) == 10


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_dispatch_config_parallel(n_jobs):
    """Check that we properly dispatch the configuration in parallel processing.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/25239
    """
    pd = pytest.importorskip("pandas")
    iris = load_iris(as_frame=True)

    class TransformerRequiredDataFrame(StandardScaler):
        # A StandardScaler that rejects non-DataFrame input, so the search below
        # only succeeds when every upstream step outputs a DataFrame.
        def fit(self, X, y=None):
            assert isinstance(X, pd.DataFrame), "X should be a DataFrame"
            return super().fit(X, y)

        def transform(self, X, copy=None):
            # Match StandardScaler.transform(X, copy=None): its second
            # parameter is `copy`, not `y`. Forward it by keyword so that no
            # unrelated value can be passed positionally into `copy`.
            assert isinstance(X, pd.DataFrame), "X should be a DataFrame"
            return super().transform(X, copy=copy)

    dropper = make_column_transformer(
        ("drop", [0]),
        remainder="passthrough",
        n_jobs=n_jobs,
    )
    param_grid = {"randomforestclassifier__max_depth": [1, 2, 3]}
    search_cv = GridSearchCV(
        make_pipeline(
            dropper,
            TransformerRequiredDataFrame(),
            RandomForestClassifier(n_estimators=5, n_jobs=n_jobs),
        ),
        param_grid,
        cv=5,
        n_jobs=n_jobs,
        error_score="raise",  # this search should not fail
    )

    # make sure that `fit` would fail in case we don't request dataframe
    with pytest.raises(AssertionError, match="X should be a DataFrame"):
        search_cv.fit(iris.data, iris.target)

    with config_context(transform_output="pandas"):
        # we expect each intermediate step to output a DataFrame
        search_cv.fit(iris.data, iris.target)

    assert not np.isnan(search_cv.cv_results_["mean_test_score"]).any()


def raise_warning():
    """Emit a single ``ConvergenceWarning`` with a fixed message."""
    message = "Convergence warning"
    warnings.warn(message, category=ConvergenceWarning)


def _yield_n_jobs_backend_combinations():
    """Yield the ``(n_jobs, backend)`` pairs used to parametrize warning tests.

    Pairs are produced with ``n_jobs`` as the outer loop over ``[1, 2]`` and
    the backend as the inner loop over loky, threading and multiprocessing.
    The ``(2, "loky")`` pair is wrapped in a ``pytest.param`` carrying the
    ``thread_unsafe`` mark; every other pair is yielded as a plain tuple.
    """
    candidate_n_jobs = [1, 2]
    candidate_backends = ["loky", "threading", "multiprocessing"]
    for n_jobs in candidate_n_jobs:
        for backend in candidate_backends:
            if (n_jobs, backend) != (2, "loky"):
                yield n_jobs, backend
                continue
            # XXX Mark thread-unsafe to avoid:
            # RuntimeError: The executor underlying Parallel has been shutdown.
            # See https://github.com/joblib/joblib/issues/1743 for more details.
            yield pytest.param(n_jobs, backend, marks=pytest.mark.thread_unsafe)


@pytest.mark.parametrize("n_jobs, backend", _yield_n_jobs_backend_combinations())
def test_filter_warning_propagates(n_jobs, backend):
    """Check warning propagates to the job.

    The caller turns ``ConvergenceWarning`` into an error; the warning emitted
    inside each job must therefore surface as a raised ``ConvergenceWarning``.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=ConvergenceWarning)

        # Build the runner and the lazy task generator while the "error"
        # filter is active, exactly like the original inline call did.
        runner = Parallel(n_jobs=n_jobs, backend=backend)
        tasks = (delayed(raise_warning)() for _ in range(2))
        with pytest.raises(ConvergenceWarning):
            runner(tasks)


def get_warning_filters():
    """Return the warnings filters in effect for the calling context.

    On free-threaded Python >= 3.14, filters live in a ContextVar and
    ``warnings.filters`` is not updated inside ``warnings.catch_warnings``;
    ``warnings._get_filters()`` must be used instead. Older interpreters lack
    that helper, so fall back to the module-level ``warnings.filters`` list.
    See
    https://docs.python.org/3.14/whatsnew/3.14.html#concurrent-safe-warnings-control
    """
    private_getter = getattr(warnings, "_get_filters", None)
    if private_getter is None:
        return warnings.filters
    return private_getter()


def test_check_warnings_threading():
    """Check that warnings filters are set correctly in the threading backend.

    Each threading worker must report the same warnings filters as the main
    thread, including the "error" filter installed for ConvergenceWarning.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=ConvergenceWarning)

        main_warning_filters = get_warning_filters()

        assert ("error", None, ConvergenceWarning, None, 0) in main_warning_filters

        all_worker_warning_filters = Parallel(n_jobs=2, backend="threading")(
            delayed(get_warning_filters)() for _ in range(2)
        )

        def normalize_main_module(filters):
            """Return ``filters`` with any ``__main__`` module regex replaced by its pattern.

            In Python 3.14 free-threaded, there is a small discrepancy: the
            main thread's warning filters have an entry with
            module="__main__" (a plain string), whereas in the workers it is a
            compiled regex. Replacing such regexes with their pattern string
            makes both sides comparable.
            """
            # Iterate over the `filters` argument. Reading the enclosing
            # `main_warning_filters` here instead would make the final
            # assertion compare the main filters with themselves.
            return [
                (
                    action,
                    message,
                    type_,
                    module.pattern
                    if isinstance(module, re.Pattern) and "__main__" in str(module)
                    else module,
                    lineno,
                )
                for action, message, type_, module, lineno in filters
            ]

        expected = normalize_main_module(main_warning_filters)
        for worker_warning_filter in all_worker_warning_filters:
            assert normalize_main_module(worker_warning_filter) == expected


@pytest.mark.xfail(_IS_WASM, reason="Pyodide always use the sequential backend")
def test_filter_warning_propagates_no_side_effect_with_loky_backend():
    """Check that sklearn's warning filters do not leak into reused loky workers.

    A first sklearn ``Parallel`` run executes while ConvergenceWarning is set
    to "error". A subsequent plain-joblib run on the loky backend must still
    be able to emit ConvergenceWarning without it being turned into an error.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=ConvergenceWarning)

        sklearn_tasks = (delayed(time.sleep)(0) for _ in range(10))
        Parallel(n_jobs=2, backend="loky")(sklearn_tasks)

        # Loky workers are reused between calls, so the filters inside them
        # must have been reset to their original value. Using joblib directly
        # should therefore not turn ConvergenceWarning into an error.
        joblib_tasks = (
            joblib.delayed(warnings.warn)("Convergence warning", ConvergenceWarning)
            for _ in range(10)
        )
        joblib.Parallel(n_jobs=2, backend="loky")(joblib_tasks)
