# Source code for imblearn.utils.estimator_checks

"""Utils to check the samplers and compatibility with scikit-learn"""

# Adapted from scikit-learn
# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: MIT

from __future__ import division

import sys
import traceback

from collections import Counter

import pytest

import numpy as np
from scipy import sparse

from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.cluster import KMeans
from sklearn.preprocessing import label_binarize
from sklearn.utils.estimator_checks import check_estimator \
    as sklearn_check_estimator, check_parameters_default_constructible
from sklearn.utils.testing import assert_allclose
from sklearn.utils.testing import assert_raises_regex
from sklearn.utils.testing import set_random_state
from sklearn.utils.multiclass import type_of_target

from imblearn.over_sampling.base import BaseOverSampler
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
from imblearn.ensemble.base import BaseEnsembleSampler
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import NearMiss, ClusterCentroids

DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE']
SUPPORT_STRING = ['RandomUnderSampler', 'RandomOverSampler']
HAVE_SAMPLE_INDICES = [
    'RandomOverSampler', 'RandomUnderSampler', 'InstanceHardnessThreshold',
    'NearMiss', 'TomekLinks', 'EditedNearestNeighbours',
    'RepeatedEditedNearestNeighbours', 'AllKNN', 'OneSidedSelection',
    'CondensedNearestNeighbour', 'NeighbourhoodCleaningRule']
# FIXME: remove in 0.6
DONT_HAVE_RANDOM_STATE = ('NearMiss', 'EditedNearestNeighbours',
                          'RepeatedEditedNearestNeighbours', 'AllKNN',
                          'NeighbourhoodCleaningRule', 'TomekLinks')


def monkey_patch_check_dtype_object(name, estimator_orig):
    """Check that estimators treat ``dtype=object`` data as numeric if possible.

    ``check_estimator`` in this module installs this function in place of
    ``sklearn.utils.estimator_checks.check_dtype_object``. It differs by
    branching on ``SUPPORT_STRING``: estimators listed there are not required
    to raise a ``TypeError`` on non-numeric object entries.

    Parameters
    ----------
    name : str
        Name of the estimator; looked up in ``SUPPORT_STRING``.

    estimator_orig : estimator instance
        Estimator to check. It is cloned, so the original is left untouched.

    Raises
    ------
    Exception
        Any error raised while fitting on an object-typed target whose
        message does not contain "Unknown label type".
    """
    rng = np.random.RandomState(0)
    X = rng.rand(40, 10).astype(object)
    # Use the builtin ``int``: the ``np.int`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, and always meant the builtin anyway.
    y = np.array([0] * 10 + [1] * 30, dtype=int)
    estimator = clone(estimator_orig)
    estimator.fit(X, y)

    # An object-typed target may legitimately be rejected, but only with the
    # "Unknown label type" error; anything else is a genuine failure.
    try:
        estimator.fit(X, y.astype(object))
    except Exception as e:
        if "Unknown label type" not in str(e):
            raise

    if name not in SUPPORT_STRING:
        # A dict cannot be converted to a number, so fitting must fail.
        X[0, 0] = {'foo': 'bar'}
        msg = "argument must be a string or a number"
        assert_raises_regex(TypeError, msg, estimator.fit, X, y)
    else:
        estimator.fit(X, y)


def _yield_sampler_checks(name, Estimator):
    """Yield, in a fixed order, every sampler-specific check function.

    ``name`` and ``Estimator`` are accepted for signature symmetry with the
    other check generators; they are not used to filter the checks.
    """
    sampler_checks = (
        check_target_type,
        check_samplers_one_label,
        check_samplers_fit,
        check_samplers_fit_resample,
        check_samplers_ratio_fit_resample,
        check_samplers_sampling_strategy_fit_resample,
        check_samplers_sparse,
        check_samplers_pandas,
        check_samplers_multiclass_ova,
        check_samplers_preserve_dtype,
        check_samplers_sample_indices,
    )
    for sampler_check in sampler_checks:
        yield sampler_check


def _yield_all_checks(name, estimator):
    # trigger our checks if this is a SamplerMixin
    if hasattr(estimator, 'fit_resample'):
        for check in _yield_sampler_checks(name, estimator):
            yield check


def check_estimator(Estimator, run_sampler_tests=True):
    """Check if estimator adheres to scikit-learn conventions and imbalanced-learn

    This function runs scikit-learn's common estimator test-suite (input
    validation, shapes, etc.) and the default-constructibility check. When
    requested, it then runs the sampler-specific checks of this module, which
    are only produced for estimators exposing ``fit_resample``.

    Parameters
    ----------
    Estimator : class
        Class to check. Estimator is a class object (not an instance).

    run_sampler_tests : bool, default=True
        Whether or not to run the sampler-specific tests.
    """
    name = Estimator.__name__
    # monkey patch check_dtype_object for the sampler allowing strings
    # NOTE(review): the patch is left installed on the scikit-learn module
    # after this function returns.
    import sklearn.utils.estimator_checks
    sklearn.utils.estimator_checks.check_dtype_object = \
        monkey_patch_check_dtype_object
    # scikit-learn common tests
    sklearn_check_estimator(Estimator)
    check_parameters_default_constructible(name, Estimator)
    if run_sampler_tests:
        for check in _yield_all_checks(name, Estimator):
            check(name, Estimator)
def check_target_type(name, Estimator):
    """Check that continuous and multilabel targets are rejected.

    Parameters
    ----------
    name : str
        Name of the estimator; looked up in ``DONT_HAVE_RANDOM_STATE``.

    Estimator : class
        Sampler class, instantiated with its default parameters.
    """
    # a continuous target must be rejected with a ValueError
    X = np.random.random((20, 2))
    y = np.linspace(0, 1, 20)
    estimator = Estimator()
    # FIXME: in 0.6 set the random_state for all
    if name not in DONT_HAVE_RANDOM_STATE:
        set_random_state(estimator)
    with pytest.raises(ValueError, match="Unknown label type: 'continuous'"):
        estimator.fit_resample(X, y)
    # if the target is multilabel then we should raise an error
    rng = np.random.RandomState(42)
    y = rng.randint(2, size=(20, 3))
    with pytest.raises(ValueError, match="'y' should encode the multiclass"):
        estimator.fit_resample(X, y)


def check_samplers_one_label(name, Sampler):
    """Check how a sampler reacts to a target containing a single class.

    A ``ValueError`` whose representation mentions 'class' is accepted. Any
    other exception is reported on stdout and re-raised. A sampler that does
    not raise at all also passes.
    """
    error_string_fit = "Sampler can't balance when only one class is present."
    sampler = Sampler()
    X = np.random.random((20, 2))
    y = np.zeros(20)
    try:
        sampler.fit_resample(X, y)
    except ValueError as e:
        if 'class' not in repr(e):
            print(error_string_fit, Sampler, e)
            traceback.print_exc(file=sys.stdout)
            raise
        else:
            return
    except Exception as exc:
        # report the offending sampler class (not the ``traceback`` module)
        print(error_string_fit, Sampler, exc)
        traceback.print_exc(file=sys.stdout)
        raise


def check_samplers_fit(name, Sampler):
    """Check that ``fit_resample`` sets the ``sampling_strategy_`` attribute."""
    sampler = Sampler()
    X = np.random.random((30, 2))
    y = np.array([1] * 20 + [0] * 10)
    sampler.fit_resample(X, y)
    assert hasattr(sampler, 'sampling_strategy_'), \
        "No fitted attribute sampling_strategy_"


def check_samplers_fit_resample(name, Sampler):
    """Check the class distribution produced by a default sampler.

    The expectation depends on which base class the sampler derives from
    (over-sampler, under-sampler, cleaning sampler or ensemble sampler).
    """
    sampler = Sampler()
    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)
    target_stats = Counter(y)
    X_res, y_res = sampler.fit_resample(X, y)
    if isinstance(sampler, BaseOverSampler):
        # every class must reach at least the size of the majority class
        n_samples = max(target_stats.values())
        assert all(value >= n_samples for value in Counter(y_res).values())
    elif isinstance(sampler, BaseUnderSampler):
        n_samples = min(target_stats.values())
        target_stats_res = Counter(y_res)
        if name == 'InstanceHardnessThreshold':
            # IHT does not enforce the number of samples but provide a number
            # of samples the closest to the desired target.
            assert all(target_stats_res[k] <= target_stats[k]
                       for k in target_stats.keys())
        else:
            assert all(value == n_samples
                       for value in target_stats_res.values())
    elif isinstance(sampler, BaseCleaningSampler):
        # every class but the minority one must have lost samples
        target_stats_res = Counter(y_res)
        class_minority = min(target_stats, key=target_stats.get)
        assert all(target_stats[class_sample] > target_stats_res[class_sample]
                   for class_sample in target_stats.keys()
                   if class_sample != class_minority)
    elif isinstance(sampler, BaseEnsembleSampler):
        # only the first subset of the ensemble is inspected
        y_ensemble = y_res[0]
        n_samples = min(target_stats.values())
        assert all(value == n_samples
                   for value in Counter(y_ensemble).values())


# FIXME remove in 0.6 -> ratio will be deprecated
def check_samplers_ratio_fit_resample(name, Sampler):
    """Check that a ``ratio`` targeting classes 0 and 2 leaves class 1 alone.

    Samplers listed in ``DONT_SUPPORT_RATIO`` are skipped.
    """
    if name in DONT_SUPPORT_RATIO:
        return
    # in this test we will force all samplers to not change the class 1
    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)
    sampler = Sampler()
    expected_stat = Counter(y)[1]
    if isinstance(sampler, BaseOverSampler):
        ratio = {2: 498, 0: 498}
        sampler.set_params(ratio=ratio)
        X_res, y_res = sampler.fit_resample(X, y)
        assert Counter(y_res)[1] == expected_stat
    elif isinstance(sampler, BaseUnderSampler):
        ratio = {2: 201, 0: 201}
        sampler.set_params(ratio=ratio)
        X_res, y_res = sampler.fit_resample(X, y)
        assert Counter(y_res)[1] == expected_stat
    elif isinstance(sampler, BaseCleaningSampler):
        ratio = {2: 201, 0: 201}
        sampler.set_params(ratio=ratio)
        X_res, y_res = sampler.fit_resample(X, y)
        assert Counter(y_res)[1] == expected_stat
    # deliberately a separate ``if``: it is evaluated whatever branch ran above
    if isinstance(sampler, BaseEnsembleSampler):
        ratio = {2: 201, 0: 201}
        sampler.set_params(ratio=ratio)
        X_res, y_res = sampler.fit_resample(X, y)
        y_ensemble = y_res[0]
        assert Counter(y_ensemble)[1] == expected_stat


def check_samplers_sampling_strategy_fit_resample(name, Sampler):
    """Check that ``sampling_strategy`` on classes 0 and 2 leaves class 1 alone."""
    # in this test we will force all samplers to not change the class 1
    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)
    sampler = Sampler()
    expected_stat = Counter(y)[1]
    if isinstance(sampler, BaseOverSampler):
        sampling_strategy = {2: 498, 0: 498}
        sampler.set_params(sampling_strategy=sampling_strategy)
        X_res, y_res = sampler.fit_resample(X, y)
        assert Counter(y_res)[1] == expected_stat
    elif isinstance(sampler, BaseUnderSampler):
        sampling_strategy = {2: 201, 0: 201}
        sampler.set_params(sampling_strategy=sampling_strategy)
        X_res, y_res = sampler.fit_resample(X, y)
        assert Counter(y_res)[1] == expected_stat
    elif isinstance(sampler, BaseCleaningSampler):
        # cleaning samplers are given the list of classes to clean
        sampling_strategy = [2, 0]
        sampler.set_params(sampling_strategy=sampling_strategy)
        X_res, y_res = sampler.fit_resample(X, y)
        assert Counter(y_res)[1] == expected_stat
    # deliberately a separate ``if``: it is evaluated whatever branch ran above
    if isinstance(sampler, BaseEnsembleSampler):
        sampling_strategy = {2: 201, 0: 201}
        sampler.set_params(sampling_strategy=sampling_strategy)
        X_res, y_res = sampler.fit_resample(X, y)
        y_ensemble = y_res[0]
        assert Counter(y_ensemble)[1] == expected_stat


def check_samplers_sparse(name, Sampler):
    """Check that sparse and dense inputs lead to the same resampled data."""
    # check that sparse matrices can be passed through the sampler leading to
    # the same results than dense
    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)
    X_sparse = sparse.csr_matrix(X)
    if isinstance(Sampler(), SMOTE):
        samplers = [
            Sampler(random_state=0, kind=kind)
            for kind in ('regular', 'borderline1', 'borderline2', 'svm')
        ]
    elif isinstance(Sampler(), NearMiss):
        samplers = [Sampler(version=version) for version in (1, 2, 3)]
    elif isinstance(Sampler(), ClusterCentroids):
        # set KMeans to full since it support sparse and dense
        samplers = [
            Sampler(
                random_state=0,
                voting='soft',
                estimator=KMeans(random_state=1, algorithm='full'))
        ]
    else:
        samplers = [Sampler()]
    for sampler in samplers:
        # FIXME: in 0.6 set the random_state for all
        if name not in DONT_HAVE_RANDOM_STATE:
            set_random_state(sampler)
        X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
        X_res, y_res = sampler.fit_resample(X, y)
        if not isinstance(sampler, BaseEnsembleSampler):
            assert sparse.issparse(X_res_sparse)
            # ``toarray()`` rather than the ``.A`` shortcut, which recent
            # SciPy releases no longer provide
            assert_allclose(X_res_sparse.toarray(), X_res)
            assert_allclose(y_res_sparse, y_res)
        else:
            # distinct loop names so the target ``y`` is not clobbered for
            # the next sampler
            for x_sp, x_dense, y_sp, y_dense in zip(X_res_sparse, X_res,
                                                    y_res_sparse, y_res):
                assert sparse.issparse(x_sp)
                assert_allclose(x_sp.toarray(), x_dense)
                assert_allclose(y_sp, y_dense)


def check_samplers_pandas(name, Sampler):
    """Check that a pandas DataFrame input gives the same result as an array.

    The check is skipped when pandas cannot be imported.
    """
    pd = pytest.importorskip("pandas")
    # Check that the samplers handle pandas dataframe and pandas series
    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)
    X_pd = pd.DataFrame(X)
    if isinstance(Sampler(), SMOTE):
        samplers = [
            Sampler(random_state=0, kind=kind)
            for kind in ('regular', 'borderline1', 'borderline2', 'svm')
        ]
    elif isinstance(Sampler(), NearMiss):
        samplers = [Sampler(version=version) for version in (1, 2, 3)]
    else:
        samplers = [Sampler()]
    for sampler in samplers:
        # FIXME: in 0.6 set the random_state for all
        if name not in DONT_HAVE_RANDOM_STATE:
            set_random_state(sampler)
        X_res_pd, y_res_pd = sampler.fit_resample(X_pd, y)
        X_res, y_res = sampler.fit_resample(X, y)
        assert_allclose(X_res_pd, X_res)
        assert_allclose(y_res_pd, y_res)


def check_samplers_multiclass_ova(name, Sampler):
    """Check that a one-vs-all encoded target gives the same resampling."""
    # Check that multiclass target lead to the same results than OVA encoding
    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)
    # ``classes`` is passed by keyword: recent scikit-learn releases no
    # longer accept it positionally
    y_ova = label_binarize(y, classes=np.unique(y))
    sampler = Sampler()
    # FIXME: in 0.6 set the random_state for all
    if name not in DONT_HAVE_RANDOM_STATE:
        set_random_state(sampler)
    X_res, y_res = sampler.fit_resample(X, y)
    X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
    assert_allclose(X_res, X_res_ova)
    if issubclass(Sampler, BaseEnsembleSampler):
        for batch_y, batch_y_ova in zip(y_res, y_res_ova):
            assert type_of_target(batch_y_ova) == type_of_target(y_ova)
            assert_allclose(batch_y, batch_y_ova.argmax(axis=1))
    else:
        assert type_of_target(y_res_ova) == type_of_target(y_ova)
        assert_allclose(y_res, y_res_ova.argmax(axis=1))


def check_samplers_preserve_dtype(name, Sampler):
    """Check that resampling keeps the dtypes of ``X`` and ``y``."""
    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)
    # Cast X and y to not default dtype
    X = X.astype(np.float32)
    y = y.astype(np.int32)
    sampler = Sampler()
    # FIXME: in 0.6 set the random_state for all
    if name not in DONT_HAVE_RANDOM_STATE:
        set_random_state(sampler)
    X_res, y_res = sampler.fit_resample(X, y)
    assert X.dtype == X_res.dtype, "X dtype is not preserved"
    assert y.dtype == y_res.dtype, "y dtype is not preserved"


def check_samplers_sample_indices(name, Sampler):
    """Check that ``sample_indices_`` exists only for the expected samplers.

    Samplers named in ``HAVE_SAMPLE_INDICES`` must expose the attribute after
    ``fit_resample``; all the others must not.
    """
    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)
    sampler = Sampler()
    sampler.fit_resample(X, y)
    if name in HAVE_SAMPLE_INDICES:
        assert hasattr(sampler, 'sample_indices_')
    else:
        assert not hasattr(sampler, 'sample_indices_')