You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
155 lines
4.5 KiB
155 lines
4.5 KiB
"""
|
|
Module containing utilities for NDFrame.sample() and .GroupBy.sample()
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
import numpy as np
|
|
|
|
from pandas._libs import lib
|
|
|
|
from pandas.core.dtypes.generic import (
|
|
ABCDataFrame,
|
|
ABCSeries,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from pandas._typing import AxisInt
|
|
|
|
from pandas.core.generic import NDFrame
|
|
|
|
|
|
def preprocess_weights(obj: NDFrame, weights, axis: AxisInt) -> np.ndarray:
    """
    Process and validate the `weights` argument to `NDFrame.sample` and
    `.GroupBy.sample`.

    Returns `weights` as an ndarray[np.float64], validated except for normalizing
    weights (because that must be done groupwise in groupby sampling).

    Parameters
    ----------
    obj : NDFrame
        The object being sampled.
    weights : Series, str or array-like
        A Series is reindexed against ``obj.axes[axis]``; a string is looked
        up as a column of ``obj`` (DataFrame with ``axis == 0`` only); anything
        else is passed to the Series constructor as-is.
    axis : int
        The axis of `obj` that is being sampled.

    Returns
    -------
    np.ndarray
        Float64 weights with NaN entries replaced by 0.

    Raises
    ------
    KeyError
        If `weights` is a string that is not a column of `obj`.
    ValueError
        If a string is passed for anything other than row sampling of a
        DataFrame, if the length does not match the sampled axis, or if the
        weights contain `inf` or negative values.
    """
    # If a series, align with frame
    if isinstance(weights, ABCSeries):
        weights = weights.reindex(obj.axes[axis])

    # Strings acceptable if a dataframe and axis = 0
    if isinstance(weights, str):
        if not isinstance(obj, ABCDataFrame):
            raise ValueError(
                "Strings cannot be passed as weights when sampling from a Series."
            )
        if axis != 0:
            raise ValueError(
                "Strings can only be passed to "
                "weights when sampling from rows on "
                "a DataFrame"
            )
        try:
            weights = obj[weights]
        except KeyError as err:
            raise KeyError("String passed to weights not a valid column") from err

    # Build a Series of the right flavour so the float64 conversion goes
    # through the same constructor the object itself would use.
    if isinstance(obj, ABCSeries):
        func = obj._constructor
    else:
        func = obj._constructor_sliced

    weights = func(weights, dtype="float64")._values

    if len(weights) != obj.shape[axis]:
        raise ValueError("Weights and axis to be sampled must be of same length")

    if lib.has_infs(weights):
        raise ValueError("weight vector may not include `inf` values")

    if (weights < 0).any():
        raise ValueError("weight vector may not include negative values")

    missing = np.isnan(weights)
    if missing.any():
        # Don't modify weights in place
        weights = weights.copy()
        weights[missing] = 0
    return weights
|
|
|
|
|
|
def process_sampling_size(
    n: int | None, frac: float | None, replace: bool
) -> int | None:
    """
    Validate the `n` / `frac` pair given to `NDFrame.sample` and
    `.GroupBy.sample` and resolve it to a sampling size.

    Parameters
    ----------
    n : int or None
        Requested number of items.
    frac : float or None
        Requested fraction of items.
    replace : bool
        Whether sampling is done with replacement; required for ``frac > 1``.

    Returns
    -------
    int or None
        The constant sampling size, or None when `frac` should be used
        (variable sampling sizes). Defaults to 1 when neither is given.

    Raises
    ------
    ValueError
        If both `n` and `frac` are given, if either is negative, if `n` is
        not integral, or if ``frac > 1`` without replacement.
    """
    if n is not None:
        # An explicit count was given; it must be the only size argument.
        if frac is not None:
            raise ValueError("Please enter a value for `frac` OR `n`, not both")
        if n < 0:
            raise ValueError(
                "A negative number of rows requested. Please provide `n` >= 0."
            )
        if n % 1 != 0:
            raise ValueError("Only integers accepted as `n` values")
        return n

    if frac is None:
        # Neither argument supplied: sample a single item.
        return 1

    if frac > 1 and not replace:
        raise ValueError(
            "Replace has to be set to `True` when "
            "upsampling the population `frac` > 1."
        )
    if frac < 0:
        raise ValueError(
            "A negative number of rows requested. Please provide `frac` >= 0."
        )

    # Signal to the caller that `frac` determines the size.
    return None
|
|
|
|
|
|
def sample(
    obj_len: int,
    size: int,
    replace: bool,
    weights: np.ndarray | None,
    random_state: np.random.RandomState | np.random.Generator,
) -> np.ndarray:
    """
    Draw `size` random positions out of `np.arange(obj_len)`.

    Parameters
    ----------
    obj_len : int
        Number of positions available to draw from.
    size : int
        How many positions to draw.
    replace : bool
        Whether the same position may be drawn more than once.
    weights : np.ndarray[np.float64] or None
        Relative weights for each position; they are divided by their sum
        before use. None means every position is equally likely.
    random_state: np.random.RandomState or np.random.Generator
        Source of randomness for the draw.

    Returns
    -------
    np.ndarray[np.intp]

    Raises
    ------
    ValueError
        If `weights` is given and sums to zero.
    """
    probabilities = weights
    if probabilities is not None:
        total = probabilities.sum()
        if total == 0:
            raise ValueError("Invalid weights: weights sum to zero")
        # Division builds a new array, so the caller's weights stay untouched.
        probabilities = probabilities / total

    drawn = random_state.choice(
        obj_len, size=size, replace=replace, p=probabilities
    )
    return drawn.astype(np.intp, copy=False)
|