Source code for imblearn.ensemble.base

"""
Base class for the ensemble method.
"""
# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: MIT

import numpy as np

from sklearn.preprocessing import label_binarize

from ..base import BaseSampler
from ..utils import check_sampling_strategy


class BaseEnsembleSampler(BaseSampler):
    """Base class for ensemble algorithms.

    Warning: This class should not be used directly. Use the derive classes
    instead.
    """

    _sampling_type = 'ensemble'

    def fit_resample(self, X, y):
        """Resample the dataset.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            Matrix containing the data which have to be sampled.

        y : array-like, shape (n_samples,)
            Corresponding label for each sample in X.

        Returns
        -------
        X_resampled : {ndarray, sparse matrix}, shape \
(n_subset, n_samples_new, n_features)
            The array containing the resampled data.

        y_resampled : ndarray, shape (n_subset, n_samples_new)
            The corresponding label of `X_resampled`

        """
        self._deprecate_ratio()
        # Ensemble are a bit specific since they are returning an array of
        # resampled arrays.
        X, y, binarize_y = self._check_X_y(X, y)

        self.sampling_strategy_ = check_sampling_strategy(
            self.sampling_strategy, y, self._sampling_type)

        output = self._fit_resample(X, y)

        if binarize_y:
            y_resampled = output[1]
            classes = np.unique(y)
            y_resampled_encoded = np.array(
                [label_binarize(batch_y, classes) for batch_y in y_resampled])
            if len(output) == 2:
                return output[0], y_resampled_encoded
            return output[0], y_resampled_encoded, output[2]
        return output