Source code for imblearn.ensemble._forest

"""Forest classifiers trained on balanced boostrasp samples."""

# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: MIT

import numbers
from warnings import warn
from copy import deepcopy

import numpy as np
from numpy import float32 as DTYPE
from numpy import float64 as DOUBLE
from scipy.sparse import issparse

from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble.base import _set_random_states
from sklearn.ensemble.forest import _parallel_build_trees
from sklearn.exceptions import DataConversionWarning
from sklearn.externals.joblib import Parallel, delayed
from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.utils import safe_indexing

from ..pipeline import make_pipeline
from ..under_sampling import RandomUnderSampler
from ..under_sampling.base import BaseUnderSampler
from ..utils import Substitution
from ..utils._docstring import _random_state_docstring

MAX_INT = np.iinfo(np.int32).max


def _local_parallel_build_trees(sampler, tree, forest, X, y, sample_weight,
                                tree_idx, n_trees, verbose=0,
                                class_weight=None):
    """Resample the dataset with the sampler, then build a single tree."""
    # resample before fitting the tree
    X_resampled, y_resampled = sampler.fit_sample(X, y)
    if sample_weight is not None:
        sample_weight = safe_indexing(sample_weight, sampler.sample_indices_)
    tree = _parallel_build_trees(tree, forest, X_resampled, y_resampled,
                                 sample_weight, tree_idx, n_trees,
                                 verbose=verbose, class_weight=class_weight)
    return sampler, tree


@Substitution(
    sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
    random_state=_random_state_docstring)
class BalancedRandomForestClassifier(RandomForestClassifier):
    """A balanced random forest classifier.

    A balanced random forest randomly under-samples each bootstrap sample to
    balance it.

    Read more in the :ref:`User Guide <forest>`.

    Parameters
    ----------
    n_estimators : integer, optional (default=100)
        The number of trees in the forest.

    criterion : string, optional (default="gini")
        The function to measure the quality of a split. Supported criteria are
        "gini" for the Gini impurity and "entropy" for the information gain.
        Note: this parameter is tree-specific.

    max_depth : integer or None, optional (default=None)
        The maximum depth of the tree. If None, then nodes are expanded until
        all leaves are pure or until all leaves contain less than
        ``min_samples_split`` samples.

    min_samples_split : int, float, optional (default=2)
        The minimum number of samples required to split an internal node:

        - If int, then consider ``min_samples_split`` as the minimum number.
        - If float, then ``min_samples_split`` is a fraction and
          ``ceil(min_samples_split * n_samples)`` are the minimum number of
          samples for each split.

    min_samples_leaf : int, float, optional (default=1)
        The minimum number of samples required to be at a leaf node:

        - If int, then consider ``min_samples_leaf`` as the minimum number.
        - If float, then ``min_samples_leaf`` is a fraction and
          ``ceil(min_samples_leaf * n_samples)`` are the minimum number of
          samples for each node.

    min_weight_fraction_leaf : float, optional (default=0.)
        The minimum weighted fraction of the sum total of weights (of all the
        input samples) required to be at a leaf node. Samples have equal
        weight when ``sample_weight`` is not provided.

    max_features : int, float, string or None, optional (default="auto")
        The number of features to consider when looking for the best split:

        - If int, then consider ``max_features`` features at each split.
        - If float, then ``max_features`` is a fraction and
          ``int(max_features * n_features)`` features are considered at each
          split.
        - If "auto", then ``max_features=sqrt(n_features)``.
        - If "sqrt", then ``max_features=sqrt(n_features)`` (same as "auto").
        - If "log2", then ``max_features=log2(n_features)``.
        - If None, then ``max_features=n_features``.

        Note: the search for a split does not stop until at least one valid
        partition of the node samples is found, even if it requires to
        effectively inspect more than ``max_features`` features.

    max_leaf_nodes : int or None, optional (default=None)
        Grow trees with ``max_leaf_nodes`` in best-first fashion. Best nodes
        are defined as relative reduction in impurity. If None then unlimited
        number of leaf nodes.

    min_impurity_decrease : float, optional (default=0.)
        A node will be split if this split induces a decrease of the impurity
        greater than or equal to this value.

        The weighted impurity decrease equation is the following::

            N_t / N * (impurity - N_t_R / N_t * right_impurity
                                - N_t_L / N_t * left_impurity)

        where ``N`` is the total number of samples, ``N_t`` is the number of
        samples at the current node, ``N_t_L`` is the number of samples in the
        left child, and ``N_t_R`` is the number of samples in the right child.

        ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
        if ``sample_weight`` is passed.

    bootstrap : boolean, optional (default=True)
        Whether bootstrap samples are used when building trees.

    oob_score : bool (default=False)
        Whether to use out-of-bag samples to estimate the generalization
        accuracy.
    {sampling_strategy}

    replacement : bool, optional (default=False)
        Whether or not to sample randomly with replacement.

    n_jobs : int, optional (default=1)
        The number of jobs to run in parallel for both ``fit`` and
        ``predict``. If -1, then the number of jobs is set to the number of
        cores.

    {random_state}

    verbose : int, optional (default=0)
        Controls the verbosity of the tree building process.

    warm_start : bool, optional (default=False)
        When set to ``True``, reuse the solution of the previous call to fit
        and add more estimators to the ensemble, otherwise, just fit a whole
        new forest.

    class_weight : dict, list of dicts, "balanced", "balanced_subsample" or None, optional (default=None)
        Weights associated with classes in the form of a dictionary with the
        class label as key and the weight as value. If not given, all classes
        are supposed to have weight one. For multi-output problems, a list of
        dicts can be provided in the same order as the columns of y.

        Note that for multioutput (including multilabel) weights should be
        defined for each class of every column in its own dict. For example,
        for four-class multilabel classification weights should be
        [{{0: 1, 1: 1}}, {{0: 1, 1: 5}}, {{0: 1, 1: 1}}, {{0: 1, 1: 1}}]
        instead of [{{1:1}}, {{2:5}}, {{3:1}}, {{4:1}}].

        The "balanced" mode uses the values of y to automatically adjust
        weights inversely proportional to class frequencies in the input data
        as ``n_samples / (n_classes * np.bincount(y))``.

        The "balanced_subsample" mode is the same as "balanced" except that
        weights are computed based on the bootstrap sample for every tree
        grown.

        For multi-output, the weights of each column of y will be multiplied.

        Note that these weights will be multiplied with ``sample_weight``
        (passed through the fit method) if ``sample_weight`` is specified.

    Attributes
    ----------
    estimators_ : list of DecisionTreeClassifier
        The collection of fitted sub-estimators.

    samplers_ : list of RandomUnderSampler
        The collection of fitted samplers.

    pipelines_ : list of Pipeline
        The collection of fitted pipelines (samplers + trees).

    classes_ : ndarray, shape (n_classes,) or a list of such arrays
        The classes labels (single output problem), or a list of arrays of
        class labels (multi-output problem).

    n_classes_ : int or list
        The number of classes (single output problem), or a list containing
        the number of classes for each output (multi-output problem).

    n_features_ : int
        The number of features when ``fit`` is performed.

    n_outputs_ : int
        The number of outputs when ``fit`` is performed.

    feature_importances_ : ndarray, shape (n_features,)
        The feature importances (the higher, the more important the feature).

    oob_score_ : float
        Score of the training dataset obtained using an out-of-bag estimate.

    oob_decision_function_ : ndarray, shape (n_samples, n_classes)
        Decision function computed with out-of-bag estimate on the training
        set. If n_estimators is small it might be possible that a data point
        was never left out during the bootstrap. In this case,
        ``oob_decision_function_`` might contain NaN.

    See also
    --------
    BalancedBaggingClassifier, EasyEnsembleClassifier, RUSBoostClassifier

    References
    ----------
    .. [1] Chen, Chao, Andy Liaw, and Leo Breiman. "Using random forest to
       learn imbalanced data." University of California, Berkeley 110 (2004):
       1-12.

    Examples
    --------
    >>> from imblearn.ensemble import BalancedRandomForestClassifier
    >>> from sklearn.datasets import make_classification
    >>>
    >>> X, y = make_classification(n_samples=1000, n_classes=3,
    ...                            n_informative=4, weights=[0.2, 0.3, 0.5],
    ...                            random_state=0)
    >>> clf = BalancedRandomForestClassifier(max_depth=2, random_state=0)
    >>> clf.fit(X, y)  # doctest: +ELLIPSIS
    BalancedRandomForestClassifier(...)
    >>> print(clf.feature_importances_)
    [ 0.21506735  0.0104961   0.00706549  0.17414694  0.00556422  0.00704686
      0.19779549  0.01865445  0.00608294  0.00490484  0.00866699  0.00251414
      0.00339721  0.01174379  0.09380596  0.05049964  0.0033278   0.01008566
      0.15534173  0.01379241]
    >>> print(clf.predict([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ...                     0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))
    [1]

    """

    def __init__(self,
                 n_estimators=100,
                 criterion="gini",
                 max_depth=None,
                 min_samples_split=2,
                 min_samples_leaf=2,
                 min_weight_fraction_leaf=0.,
                 max_features="auto",
                 max_leaf_nodes=None,
                 min_impurity_decrease=0.,
                 bootstrap=True,
                 oob_score=False,
                 sampling_strategy='auto',
                 replacement=False,
                 n_jobs=1,
                 random_state=None,
                 verbose=0,
                 warm_start=False,
                 class_weight=None):
        super(BalancedRandomForestClassifier, self).__init__(
            criterion=criterion,
            max_depth=max_depth,
            n_estimators=n_estimators,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            class_weight=class_weight,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease)

        self.sampling_strategy = sampling_strategy
        self.replacement = replacement

    def _validate_estimator(self, default=DecisionTreeClassifier()):
        """Check the estimator and the n_estimators attribute, set the
        `base_estimator_` attribute."""
        if not isinstance(self.n_estimators, (numbers.Integral, np.integer)):
            raise ValueError("n_estimators must be an integer, "
                             "got {0}.".format(type(self.n_estimators)))

        if self.n_estimators <= 0:
            raise ValueError("n_estimators must be greater than zero, "
                             "got {0}.".format(self.n_estimators))

        if self.base_estimator is not None:
            self.base_estimator_ = clone(self.base_estimator)
        else:
            self.base_estimator_ = clone(default)

        self.base_sampler_ = RandomUnderSampler(
            sampling_strategy=self.sampling_strategy,
            replacement=self.replacement)

    def _make_sampler_estimator(self, random_state=None):
        """Make and configure a copy of the `base_estimator_` attribute.

        Warning: This method should be used to properly instantiate new
        sub-estimators.
        """
        estimator = clone(self.base_estimator_)
        estimator.set_params(**dict((p, getattr(self, p))
                                    for p in self.estimator_params))
        sampler = clone(self.base_sampler_)

        if random_state is not None:
            _set_random_states(estimator, random_state)
            _set_random_states(sampler, random_state)

        return estimator, sampler

    def fit(self, X, y, sample_weight=None):
        """Build a forest of trees from the training set (X, y).

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            The training input samples. Internally, its dtype will be
            converted to ``dtype=np.float32``. If a sparse matrix is provided,
            it will be converted into a sparse ``csc_matrix``.

        y : array-like, shape (n_samples,) or (n_samples, n_outputs)
            The target values (class labels in classification, real numbers
            in regression).

        sample_weight : array-like, shape (n_samples,)
            Sample weights. If None, then samples are equally weighted. Splits
            that would create child nodes with net zero or negative weight are
            ignored while searching for a split in each node. In the case of
            classification, splits are also ignored if they would result in
            any single class carrying a negative weight in either child node.

        Returns
        -------
        self : object

        """
        # Validate or convert input data
        X = check_array(X, accept_sparse="csc", dtype=DTYPE)
        y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
        if sample_weight is not None:
            sample_weight = check_array(sample_weight, ensure_2d=False)
        if issparse(X):
            # Pre-sort indices to avoid that each individual tree of the
            # ensemble sorts the indices.
            X.sort_indices()

        # Remap output
        _, self.n_features_ = X.shape

        y = np.atleast_1d(y)
        if y.ndim == 2 and y.shape[1] == 1:
            warn("A column-vector y was passed when a 1d array was"
                 " expected. Please change the shape of y to "
                 "(n_samples,), for example using ravel().",
                 DataConversionWarning, stacklevel=2)

        if y.ndim == 1:
            # reshape is necessary to preserve the data contiguity, unlike
            # [:, np.newaxis] which does not.
            y = np.reshape(y, (-1, 1))

        self.n_outputs_ = y.shape[1]

        y, expanded_class_weight = self._validate_y_class_weight(y)

        if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
            y = np.ascontiguousarray(y, dtype=DOUBLE)

        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight

        # Check parameters
        self._validate_estimator()

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available"
                             " if bootstrap=True")

        random_state = check_random_state(self.random_state)

        if not self.warm_start or not hasattr(self, "estimators_"):
            # Free allocated memory, if any
            self.estimators_ = []
            self.samplers_ = []
            self.pipelines_ = []

        n_more_estimators = self.n_estimators - len(self.estimators_)

        if n_more_estimators < 0:
            raise ValueError('n_estimators=%d must be larger or equal to '
                             'len(estimators_)=%d when warm_start==True'
                             % (self.n_estimators, len(self.estimators_)))

        elif n_more_estimators == 0:
            warn("Warm-start fitting without increasing n_estimators does "
                 "not fit new trees.")
        else:
            if self.warm_start and len(self.estimators_) > 0:
                # We draw from the random state to get the random state we
                # would have got if we hadn't used a warm_start.
                random_state.randint(MAX_INT, size=len(self.estimators_))

            trees = []
            samplers = []
            for _ in range(n_more_estimators):
                tree, sampler = self._make_sampler_estimator(
                    random_state=random_state)
                trees.append(tree)
                samplers.append(sampler)

            # Parallel loop: we prefer the threading backend as the Cython
            # code for fitting the trees is internally releasing the Python
            # GIL making threading more efficient than multiprocessing in
            # that case. However, we respect any parallel_backend contexts
            # set at a higher level, since correctness does not rely on
            # using threads.
            samplers_trees = Parallel(n_jobs=self.n_jobs,
                                      verbose=self.verbose,
                                      prefer="threads")(
                delayed(_local_parallel_build_trees)(
                    s, t, self, X, y, sample_weight, i, len(trees),
                    verbose=self.verbose, class_weight=self.class_weight)
                for i, (s, t) in enumerate(zip(samplers, trees)))
            samplers, trees = zip(*samplers_trees)

            # Collect newly grown trees
            self.estimators_.extend(trees)
            self.samplers_.extend(samplers)

            # Create pipeline with the fitted samplers and trees
            self.pipelines_.extend([make_pipeline(deepcopy(s), deepcopy(t))
                                    for s, t in zip(samplers, trees)])

        if self.oob_score:
            self._set_oob_score(X, y)

        # Decapsulate classes_ attributes
        if hasattr(self, "classes_") and self.n_outputs_ == 1:
            self.n_classes_ = self.n_classes_[0]
            self.classes_ = self.classes_[0]

        return self
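

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): it fits the
# classifier defined above on a small imbalanced dataset and inspects the
# fitted samplers and per-tree pipelines documented in the class docstring.
# The dataset, class weights and parameter values below are arbitrary choices
# made only for this example.
if __name__ == "__main__":
    from sklearn.datasets import make_classification

    X, y = make_classification(n_samples=1000, n_classes=3,
                               n_informative=4, weights=[0.2, 0.3, 0.5],
                               random_state=0)
    clf = BalancedRandomForestClassifier(n_estimators=10, max_depth=2,
                                         random_state=0)
    clf.fit(X, y)

    # One RandomUnderSampler and one DecisionTreeClassifier are kept per tree,
    # together with a Pipeline chaining the two.
    print(len(clf.estimators_), len(clf.samplers_), len(clf.pipelines_))
    print(clf.predict(X[:5]))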