"""Utilities for input validation"""
# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: MIT
from __future__ import division
import warnings
from collections import OrderedDict
from numbers import Integral, Real
import numpy as np
from sklearn.base import clone
from sklearn.neighbors.base import KNeighborsMixin
from sklearn.neighbors import NearestNeighbors
from sklearn.externals import six
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.deprecation import deprecated
from ..exceptions import raise_isinstance_error
# Sampler families accepted as ``sampling_type`` by
# ``check_sampling_strategy``.
SAMPLING_KIND = (
    'over-sampling',
    'under-sampling',
    'clean-sampling',
    'ensemble',
    'bypass',
)
# Target types the samplers are meant to be compatible with.
TARGET_KIND = (
    'binary',
    'multiclass',
    'multilabel-indicator',
)
def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
    """Turn a nearest-neighbors specification into an estimator.

    Several methods in imblearn rely on nearest neighbors. Until version
    0.4, these objects can be passed at initialisation either as an integer
    or as a KNeighborsMixin; afterwards only a KNeighborsMixin will be
    accepted. This utility performs the type checking and reports an error
    when the type is wrong.

    Parameters
    ----------
    nn_name : str,
        The name associated to the object, used when reporting an error.

    nn_object : int or KNeighborsMixin,
        The object to be checked.

    additional_neighbor : int, optional (default=0)
        Number of extra neighbors some algorithms need. It is only added
        when ``nn_object`` is given as an integer.

    Returns
    -------
    nn_object : KNeighborsMixin
        A new ``NearestNeighbors`` when an integer was given, otherwise a
        clone of the estimator that was passed in.

    """
    if isinstance(nn_object, Integral):
        n_neighbors = nn_object + additional_neighbor
        return NearestNeighbors(n_neighbors=n_neighbors)
    if isinstance(nn_object, KNeighborsMixin):
        # Clone so that the caller's estimator is left untouched.
        return clone(nn_object)
    # Any other type is reported through the shared error helper.
    raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object)
def _count_class_sample(y):
unique, counts = np.unique(y, return_counts=True)
return dict(zip(unique, counts))
def check_target_type(y, indicate_one_vs_all=False):
    """Check that the target type conforms to the current samplers.

    The current samplers should be compatible with ``'binary'``,
    ``'multilabel-indicator'`` and ``'multiclass'`` targets only. A
    ``'multilabel-indicator'`` target is converted to a 1d array of class
    indices; any other target is returned unchanged.

    Parameters
    ----------
    y : ndarray,
        The array containing the target.

    indicate_one_vs_all : bool, optional
        Whether to also return a flag telling if the targets were encoded
        in a one-vs-all fashion.

    Returns
    -------
    y : ndarray,
        The returned target.

    is_one_vs_all : bool, optional
        Indicate if the target was originally encoded in a one-vs-all
        fashion. Only returned if ``indicate_one_vs_all=True``.

    Raises
    ------
    ValueError
        If ``y`` is a ``'multilabel-indicator'`` target in which some row
        sums to more than 1.

    """
    type_y = type_of_target(y)
    is_one_vs_all = type_y == 'multilabel-indicator'
    if is_one_vs_all:
        # A row summing to more than 1 cannot be mapped to a single class.
        if np.any(y.sum(axis=1) > 1):
            raise ValueError(
                "When 'y' corresponds to '{}', 'y' should encode the "
                "multiclass (a single 1 by row).".format(type_y))
        y = y.argmax(axis=1)

    if indicate_one_vs_all:
        return y, is_one_vs_all
    return y
def _sampling_strategy_all(y, sampling_type):
    """Build the sampling target when every class is targeted.

    For over-sampling, each class is mapped to the number of samples
    missing to reach the size of the largest class. For under-sampling and
    clean-sampling, each class is mapped to the size of the smallest class.
    Any other ``sampling_type`` raises ``NotImplementedError``.
    """
    class_counts = _count_class_sample(y)
    if sampling_type == 'over-sampling':
        largest = max(class_counts.values())
        targets = {}
        for label, count in class_counts.items():
            targets[label] = largest - count
        return targets
    if sampling_type in ('under-sampling', 'clean-sampling'):
        smallest = min(class_counts.values())
        return dict.fromkeys(class_counts, smallest)
    raise NotImplementedError
def _sampling_strategy_majority(y, sampling_type):
"""Returns sampling target by targeting the majority class only."""
if sampling_type == 'over-sampling':
raise ValueError("'sampling_strategy'='majority' cannot be used with"
" over-sampler.")
elif (sampling_type == 'under-sampling' or
sampling_type == 'clean-sampling'):
target_stats = _count_class_sample(y)
class_majority = max(target_stats, key=target_stats.get)
n_sample_minority = min(target_stats.values())
sampling_strategy = {
key: n_sample_minority
for key in target_stats.keys() if key == class_majority
}
else:
raise NotImplementedError
return sampling_strategy
def _sampling_strategy_not_majority(y, sampling_type):
    """Build the sampling target for all classes but the majority.

    For over-sampling, each targeted class is mapped to the number of
    samples missing to reach the size of the largest class. For
    under-sampling and clean-sampling, each targeted class is mapped to
    the size of the smallest class. Any other ``sampling_type`` raises
    ``NotImplementedError``.
    """
    class_counts = _count_class_sample(y)
    targets = {}
    if sampling_type == 'over-sampling':
        largest = max(class_counts.values())
        majority_label = max(class_counts, key=class_counts.get)
        for label, count in class_counts.items():
            if label != majority_label:
                targets[label] = largest - count
        return targets
    if sampling_type in ('under-sampling', 'clean-sampling'):
        smallest = min(class_counts.values())
        majority_label = max(class_counts, key=class_counts.get)
        for label in class_counts:
            if label != majority_label:
                targets[label] = smallest
        return targets
    raise NotImplementedError
def _sampling_strategy_not_minority(y, sampling_type):
    """Build the sampling target for all classes but the minority.

    For over-sampling, each targeted class is mapped to the number of
    samples missing to reach the size of the largest class. For
    under-sampling and clean-sampling, each targeted class is mapped to
    the size of the smallest class. Any other ``sampling_type`` raises
    ``NotImplementedError``.
    """
    class_counts = _count_class_sample(y)
    targets = {}
    if sampling_type == 'over-sampling':
        largest = max(class_counts.values())
        minority_label = min(class_counts, key=class_counts.get)
        for label, count in class_counts.items():
            if label != minority_label:
                targets[label] = largest - count
        return targets
    if sampling_type in ('under-sampling', 'clean-sampling'):
        smallest = min(class_counts.values())
        minority_label = min(class_counts, key=class_counts.get)
        for label in class_counts:
            if label != minority_label:
                targets[label] = smallest
        return targets
    raise NotImplementedError
def _sampling_strategy_minority(y, sampling_type):
    """Build the sampling target when only the minority class is targeted.

    For over-sampling, the minority class is mapped to the number of
    samples missing to reach the size of the largest class. Under-sampling
    and clean-sampling are rejected with a ``ValueError`` and any other
    ``sampling_type`` raises ``NotImplementedError``.
    """
    class_counts = _count_class_sample(y)
    if sampling_type == 'over-sampling':
        largest = max(class_counts.values())
        minority_label = min(class_counts, key=class_counts.get)
        targets = {}
        for label, count in class_counts.items():
            if label == minority_label:
                targets[label] = largest - count
        return targets
    if sampling_type in ('under-sampling', 'clean-sampling'):
        raise ValueError("'sampling_strategy'='minority' cannot be used with"
                         " under-sampler and clean-sampler.")
    raise NotImplementedError
def _sampling_strategy_auto(y, sampling_type):
"""Returns sampling target auto for over-sampling and not-minority for
under-sampling."""
if sampling_type == 'over-sampling':
return _sampling_strategy_not_majority(y, sampling_type)
elif (sampling_type == 'under-sampling' or
sampling_type == 'clean-sampling'):
return _sampling_strategy_not_minority(y, sampling_type)
def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
    """Validate a dict ``sampling_strategy`` and convert it for the sampler.

    The keys must be classes present in ``y`` and the values must not be
    negative. For over-sampling, the returned values are the number of
    samples to add to each class; for under-sampling and clean-sampling
    the requested values are returned as given. Any other
    ``sampling_type`` raises ``NotImplementedError``.
    """
    class_counts = _count_class_sample(y)

    # Every targeted class has to be present in y.
    unknown_classes = set(sampling_strategy.keys()) - set(class_counts.keys())
    if unknown_classes:
        raise ValueError("The {} target class is/are not present in the"
                         " data.".format(unknown_classes))

    # A negative number of samples is meaningless.
    for requested in sampling_strategy.values():
        if requested < 0:
            raise ValueError(
                "The number of samples in a class cannot be negative."
                "'sampling_strategy' contains some negative value: {}"
                .format(sampling_strategy))

    converted = {}
    if sampling_type == 'over-sampling':
        largest = max(class_counts.values())
        majority_label = max(class_counts, key=class_counts.get)
        for label, requested in sampling_strategy.items():
            current = class_counts[label]
            if requested < current:
                raise ValueError("With over-sampling methods, the number"
                                 " of samples in a class should be greater"
                                 " or equal to the original number of samples."
                                 " Originally, there is {} samples and {}"
                                 " samples are asked.".format(
                                     current, requested))
            if requested > largest:
                warnings.warn("After over-sampling, the number of samples ({})"
                              " in class {} will be larger than the number of"
                              " samples in the majority class (class #{} ->"
                              " {})".format(requested, label,
                                            majority_label,
                                            largest))
            # Over-samplers expect the number of samples to generate.
            converted[label] = requested - current
    elif sampling_type == 'under-sampling':
        for label, requested in sampling_strategy.items():
            current = class_counts[label]
            if requested > current:
                raise ValueError("With under-sampling methods, the number of"
                                 " samples in a class should be less or equal"
                                 " to the original number of samples."
                                 " Originally, there is {} samples and {}"
                                 " samples are asked.".format(
                                     current, requested))
            converted[label] = requested
    elif sampling_type == 'clean-sampling':
        # FIXME: Turn into an error in 0.6
        warnings.warn("'sampling_strategy' as a dict for cleaning methods is "
                      "deprecated and will raise an error in version 0.6. "
                      "Please give a list of the classes to be targeted by the"
                      " sampling.", DeprecationWarning)
        # clean-sampling can be more permissive since those samplers do not
        # use samples
        converted.update(sampling_strategy.items())
    else:
        raise NotImplementedError
    return converted
def _sampling_strategy_list(sampling_strategy, y, sampling_type):
"""With cleaning methods, sampling_strategy can be a list to target the
class of interest."""
if sampling_type != 'clean-sampling':
raise ValueError("'sampling_strategy' cannot be a list for samplers "
"which are not cleaning methods.")
target_stats = _count_class_sample(y)
# check that all keys in sampling_strategy are also in y
set_diff_sampling_strategy_target = (
set(sampling_strategy) - set(target_stats.keys()))
if len(set_diff_sampling_strategy_target) > 0:
raise ValueError("The {} target class is/are not present in the"
" data.".format(set_diff_sampling_strategy_target))
return {
class_sample: min(target_stats.values())
for class_sample in sampling_strategy
}
def _sampling_strategy_float(sampling_strategy, y, sampling_type):
    """Convert a float ratio into a sampling target for a binary problem.

    For over-sampling, the minority class is mapped to the number of
    samples to add so that it reaches ``sampling_strategy`` times the size
    of the majority class. For under-sampling, the majority class is
    mapped to the size of the minority class divided by
    ``sampling_strategy``.

    Parameters
    ----------
    sampling_strategy : float,
        The requested ratio between the minority and the majority class.

    y : ndarray,
        The target array. It has to be a binary target.

    sampling_type : str,
        The type of sampling. Only ``'over-sampling'`` and
        ``'under-sampling'`` are supported.

    Returns
    -------
    sampling_strategy_ : dict
        A single-entry dictionary whose key is the resampled class.

    Raises
    ------
    ValueError
        If ``y`` is not binary, if ``sampling_type`` is not supported, or
        if the ratio cannot be reached by the requested kind of sampling
        (an over-sampler would have to remove samples, or an under-sampler
        would have to create some).

    """
    type_y = type_of_target(y)
    if type_y != 'binary':
        raise ValueError(
            '"sampling_strategy" can be a float only when the type '
            'of target is binary. For multi-class, use a dict.')
    target_stats = _count_class_sample(y)
    if sampling_type == 'over-sampling':
        n_sample_majority = max(target_stats.values())
        class_majority = max(target_stats, key=target_stats.get)
        sampling_strategy_ = {
            key: int(n_sample_majority * sampling_strategy - value)
            for (key, value) in target_stats.items() if key != class_majority
        }
        # A negative count means the minority class is already larger than
        # the requested proportion of the majority class.
        if any(n_samples < 0 for n_samples in sampling_strategy_.values()):
            raise ValueError("The specified ratio would require removing "
                             "samples from the minority class while trying "
                             "to generate new samples. Please increase the "
                             "ratio.")
    elif sampling_type == 'under-sampling':
        n_sample_minority = min(target_stats.values())
        class_minority = min(target_stats, key=target_stats.get)
        sampling_strategy_ = {
            key: int(n_sample_minority / sampling_strategy)
            for (key, value) in target_stats.items() if key != class_minority
        }
        # An under-sampler cannot keep more samples than the class holds.
        if any(n_samples > target_stats[target]
               for target, n_samples in sampling_strategy_.items()):
            raise ValueError("The specified ratio would require generating "
                             "new samples in the majority class while trying "
                             "to remove samples. Please increase the ratio.")
    else:
        raise ValueError("'clean-sampling' methods do not let the user "
                         "specify the sampling ratio.")
    return sampling_strategy_
def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs):
    """Sampling target validation for samplers.

    Checks that ``sampling_strategy`` is of consistent type and return a
    dictionary containing each targeted class with its corresponding
    number of sample. It is used in :class:`imblearn.base.BaseSampler`.

    Parameters
    ----------
    sampling_strategy : float, str, dict, list or callable,
        Sampling information to sample the data set.

        - When ``float``, it has to be in the range (0, 1]:

            For **under-sampling methods**, it corresponds to the ratio
            :math:`\\alpha_{us}` defined by :math:`\\alpha_{us} = N_{m} /
            N_{rM}` where :math:`N_{rM}` and :math:`N_{m}` are the
            number of samples in the majority class after resampling and the
            number of samples in the minority class, respectively;

            For **over-sampling methods**, it corresponds to the ratio
            :math:`\\alpha_{os}` defined by :math:`\\alpha_{os} = N_{rm} /
            N_{M}` where :math:`N_{rm}` and :math:`N_{M}` are the
            number of samples in the minority class after resampling and the
            number of samples in the majority class, respectively.

            .. warning::
               ``float`` is only available for **binary** classification. An
               error is raised for multi-class classification and with
               cleaning samplers.

        - When ``str``, specify the class targeted by the resampling. For
          **under- and over-sampling methods**, the number of samples in the
          different classes will be equalized. For **cleaning methods**, the
          number of samples will not be equal. Possible choices are:

            ``'minority'``: resample only the minority class;

            ``'majority'``: resample only the majority class;

            ``'not minority'``: resample all classes but the minority class;

            ``'not majority'``: resample all classes but the majority class;

            ``'all'``: resample all classes;

            ``'auto'``: for under-sampling and cleaning methods, equivalent
            to ``'not minority'`` and for over-sampling methods, equivalent
            to ``'not majority'``.

        - When ``dict``, the keys correspond to the targeted classes. The
          values correspond to the desired number of samples for each targeted
          class.

          .. warning::
             ``dict`` is available for both **under- and over-sampling
             methods**. With **cleaning methods** it is deprecated and a
             ``DeprecationWarning`` is issued. Use a ``list`` instead.

        - When ``list``, the list contains the targeted classes. It is used
          only for **cleaning methods**.

          .. warning::
             ``list`` is available for **cleaning methods**. An error is
             raised with **under- and over-sampling methods**.

        - When callable, function taking ``y`` and returns a ``dict``. The keys
          correspond to the targeted classes. The values correspond to the
          desired number of samples for each class.

    y : ndarray, shape (n_samples,)
        The target array.

    sampling_type : str,
        The type of sampling. Has to be one of ``SAMPLING_KIND``. For
        ``'ensemble'`` and ``'bypass'``, ``sampling_strategy`` is returned
        unchanged.

    kwargs : dict, optional
        Dictionary of additional keyword arguments to pass to
        ``sampling_strategy`` when this is a callable.

    Returns
    -------
    sampling_strategy_converted : dict,
        The converted and validated sampling target. Returns a dictionary with
        the key being the class target and the value being the desired
        number of samples. The dictionary is ordered by class.

    Raises
    ------
    ValueError
        If ``sampling_type`` is unknown, if ``y`` contains a single class,
        if ``sampling_strategy`` is an unknown string or a float outside of
        (0, 1], or if ``sampling_strategy`` is of an unsupported type.

    """
    if sampling_type not in SAMPLING_KIND:
        raise ValueError("'sampling_type' should be one of {}. Got '{}'"
                         " instead.".format(SAMPLING_KIND, sampling_type))

    # Computed once: it is needed for both the check and the message.
    n_classes = np.unique(y).size
    if n_classes <= 1:
        raise ValueError("The target 'y' needs to have more than 1 class."
                         " Got {} class instead".format(n_classes))

    # These samplers forward the strategy to their inner samplers.
    if sampling_type in ('ensemble', 'bypass'):
        return sampling_strategy

    if isinstance(sampling_strategy, six.string_types):
        if sampling_strategy not in SAMPLING_TARGET_KIND:
            # List the accepted names rather than the whole dispatch table,
            # whose values are function objects.
            raise ValueError("When 'sampling_strategy' is a string, it needs"
                             " to be one of {}. Got '{}' instead.".format(
                                 sorted(SAMPLING_TARGET_KIND),
                                 sampling_strategy))
        converted = SAMPLING_TARGET_KIND[sampling_strategy](y, sampling_type)
    elif isinstance(sampling_strategy, dict):
        converted = _sampling_strategy_dict(sampling_strategy, y,
                                            sampling_type)
    elif isinstance(sampling_strategy, list):
        converted = _sampling_strategy_list(sampling_strategy, y,
                                            sampling_type)
    elif isinstance(sampling_strategy, Real):
        if sampling_strategy <= 0 or sampling_strategy > 1:
            raise ValueError(
                "When 'sampling_strategy' is a float, it should be "
                "in the range (0, 1]. Got {} instead."
                .format(sampling_strategy))
        converted = _sampling_strategy_float(sampling_strategy, y,
                                             sampling_type)
    elif callable(sampling_strategy):
        # The callable returns a dict which is validated like a user dict.
        converted = _sampling_strategy_dict(sampling_strategy(y, **kwargs),
                                            y, sampling_type)
    else:
        # Fail loudly instead of silently returning None.
        raise ValueError("'sampling_strategy' should be a float, str, dict,"
                         " list or callable. Got {} instead.".format(
                             type(sampling_strategy)))
    return OrderedDict(sorted(converted.items()))
# Dispatch table from the string form of ``sampling_strategy`` to the helper
# building the corresponding sampling target. Each helper is called as
# ``helper(y, sampling_type)``.
SAMPLING_TARGET_KIND = {
    'minority': _sampling_strategy_minority,
    'majority': _sampling_strategy_majority,
    'not minority': _sampling_strategy_not_minority,
    'not majority': _sampling_strategy_not_majority,
    'all': _sampling_strategy_all,
    'auto': _sampling_strategy_auto,
}
@deprecated("imblearn.utils.check_ratio was deprecated in favor of "
            "imblearn.utils.check_sampling_strategy in 0.4. It will be "
            "removed in 0.6.")
def check_ratio(ratio, y, sampling_type, **kwargs):
    """Sampling target validation for samplers.

    Thin deprecated alias: all arguments are forwarded unchanged to
    :func:`imblearn.utils.check_sampling_strategy`, with ``ratio`` passed
    as ``sampling_strategy``, and its result is returned as is.

    .. deprecated:: 0.4
       This function is deprecated in favor of
       :func:`imblearn.utils.check_sampling_strategy`. It will be removed in
       0.6.

    Parameters
    ----------
    ratio : str, dict or callable,
        Ratio to use for resampling the data set. It is interpreted
        exactly like the ``sampling_strategy`` parameter of
        :func:`imblearn.utils.check_sampling_strategy`.

    y : ndarray, shape (n_samples,)
        The target array.

    sampling_type : str,
        The type of sampling, as accepted by
        :func:`imblearn.utils.check_sampling_strategy`.

    kwargs : dict, optional
        Dictionary of additional keyword arguments to pass to ``ratio``
        when it is a callable.

    Returns
    -------
    ratio_converted : dict,
        The converted and validated ratio. Returns a dictionary with
        the key being the class target and the value being the desired
        number of samples.

    """
    ratio_converted = check_sampling_strategy(
        ratio, y, sampling_type, **kwargs)
    return ratio_converted