You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

582 lines
16 KiB

from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Literal,
final,
)
import numpy as np
from pandas.core.dtypes.common import (
is_integer,
is_list_like,
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCIndex,
)
from pandas.core.dtypes.missing import (
isna,
remove_na_arraylike,
)
from pandas.io.formats.printing import pprint_thing
from pandas.plotting._matplotlib.core import (
LinePlot,
MPLPlot,
)
from pandas.plotting._matplotlib.groupby import (
create_iter_data_given_by,
reformat_hist_y_given_by,
)
from pandas.plotting._matplotlib.misc import unpack_single_str_list
from pandas.plotting._matplotlib.tools import (
create_subplots,
flatten_axes,
maybe_adjust_figure,
set_ticks_props,
)
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from pandas._typing import PlottingOrientation
from pandas import (
DataFrame,
Series,
)
class HistPlot(LinePlot):
    """
    Histogram plot that draws each column (or group) with ``Axes.hist``.

    An integer ``bins`` is resolved to explicit bin edges when the plot is
    constructed: once from the whole data set, or once per group when ``by``
    is set.
    """

    @property
    def _kind(self) -> Literal["hist", "kde"]:
        return "hist"

    def __init__(
        self,
        data,
        bins: int | np.ndarray | list[np.ndarray] = 10,
        bottom: int | np.ndarray = 0,
        *,
        range=None,
        weights=None,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        data : object
            Forwarded to ``MPLPlot.__init__``.
        bins : int, ndarray or list of ndarray, default 10
            An integer is converted to bin edges by ``_adjust_bins``; anything
            else is stored unchanged.
        bottom : int or array-like, default 0
            Baseline of the bars. List-likes are converted to an ndarray.
        range : optional
            Passed as ``range`` to ``np.histogram`` when bin edges are computed.
        weights : optional
            Stored and, when not None, sliced per column in ``_make_plot``.
        **kwargs
            Forwarded to ``MPLPlot.__init__``; ``xlabel`` and ``ylabel`` are
            also read from here.
        """
        self.bottom = np.array(bottom) if is_list_like(bottom) else bottom
        self._bin_range = range
        self.weights = weights
        self.xlabel = kwargs.get("xlabel")
        self.ylabel = kwargs.get("ylabel")
        # Do not call LinePlot.__init__ which may fill nan
        MPLPlot.__init__(self, data, **kwargs)  # pylint: disable=non-parent-init-called
        # Needs self.data / self.by / self.columns, hence after the base init.
        self.bins = self._adjust_bins(bins)

    def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
        """Turn an integer bin count into bin edges; pass anything else through."""
        if not is_integer(bins):
            return bins
        if self.by is None:
            return self._calculate_bins(self.data, bins)
        # One array of edges per group, in groupby iteration order.
        by_modified = unpack_single_str_list(self.by)
        grouped = self.data.groupby(by_modified)[self.columns]
        return [self._calculate_bins(group, bins) for _, group in grouped]

    def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
        """Calculate bins given data"""
        numeric = data.infer_objects(copy=False)._get_numeric_data()
        flat = np.ravel(numeric)
        non_missing = flat[~isna(flat)]
        _, edges = np.histogram(non_missing, bins=bins, range=self._bin_range)
        return edges

    # error: Signature of "_plot" incompatible with supertype "LinePlot"
    @classmethod
    def _plot(  # type: ignore[override]
        cls,
        ax: Axes,
        y: np.ndarray,
        style=None,
        bottom: int | np.ndarray = 0,
        column_num: int = 0,
        stacking_id=None,
        *,
        bins,
        **kwds,
    ):
        """
        Draw one histogram on ``ax`` and record its counts in the stacker.

        ``style`` is accepted but never forwarded to ``ax.hist``. Returns the
        patches produced by ``ax.hist``.
        """
        n_bars = len(bins) - 1
        if column_num == 0:
            cls._initialize_stacker(ax, stacking_id, n_bars)

        base = np.zeros(n_bars)
        bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"])

        # ignore style
        counts, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds)
        cls._update_stacker(ax, stacking_id, counts)
        return patches

    def _make_plot(self, fig: Figure) -> None:
        """Draw every column (or every group, when ``by`` is set) onto its axes."""
        colors = self._get_colors()
        stacking_id = self._get_stacking_id()
        has_by = self.by is not None

        # Re-create iterated data if `by` is assigned by users
        if has_by:
            data = create_iter_data_given_by(self.data, self._kind)
        else:
            data = self.data

        # error: Argument "data" to "_iter_data" of "MPLPlot" has incompatible
        # type "object"; expected "DataFrame | dict[Hashable, Series | DataFrame]"
        for i, (label, y) in enumerate(self._iter_data(data=data)):  # type: ignore[arg-type]
            ax = self._get_ax(i)

            kwds = self.kwds.copy()
            if self.color is not None:
                kwds["color"] = self.color

            label = self._mark_right_label(pprint_thing(label), index=i)
            kwds["label"] = label

            style, kwds = self._apply_style_colors(colors, kwds, i, label)
            if style is not None:
                kwds["style"] = style

            self._make_plot_keywords(kwds, y)

            # the bins is multi-dimension array now and each plot need only 1-d and
            # when by is applied, label should be columns that are grouped
            if has_by:
                kwds["bins"] = kwds["bins"][i]
                kwds["label"] = self.columns
                kwds.pop("color")

            if self.weights is not None:
                kwds["weights"] = type(self)._get_column_weights(self.weights, i, y)

            y = reformat_hist_y_given_by(y, self.by)

            artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)

            # when by is applied, show title for subplots to know which group it is
            if has_by:
                ax.set_title(pprint_thing(label))

            self._append_legend_handles_labels(artists[0], label)

    def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None:
        """merge BoxPlot/KdePlot properties to passed kwds"""
        # y is required for KdePlot
        kwds.update(bottom=self.bottom, bins=self.bins)

    @final
    @staticmethod
    def _get_column_weights(weights, i: int, y):
        """
        Select the weights for column ``i`` and drop entries where ``y`` is NA.

        Raises
        ------
        ValueError
            If ``weights`` is multi-column but has no column ``i``.
        """
        if weights is None:
            return weights

        # We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
        # and each sub-array (10,) will be called in each iteration. If users only
        # provide 1D array, we assume the same weights is used for all iterations
        if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
            try:
                weights = weights[:, i]
            except IndexError as err:
                raise ValueError(
                    "weights must have the same shape as data, "
                    "or be a single column"
                ) from err
        return weights[~isna(y)]

    def _post_plot_logic(self, ax: Axes, data) -> None:
        """Label the axes, putting "Frequency" on the count axis by default."""
        xlabel, ylabel = self.xlabel, self.ylabel
        if self.orientation == "horizontal":
            # Counts run along x for horizontal bars.
            if xlabel is None:
                xlabel = "Frequency"
        elif ylabel is None:
            ylabel = "Frequency"

        # error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible
        # type "Hashable"; expected "str"
        ax.set_xlabel(xlabel)  # type: ignore[arg-type]
        ax.set_ylabel(ylabel)  # type: ignore[arg-type]

    @property
    def orientation(self) -> PlottingOrientation:
        """ "horizontal" only when requested via ``kwds``, otherwise "vertical"."""
        requested = self.kwds.get("orientation", None)
        return "horizontal" if requested == "horizontal" else "vertical"
class KdePlot(HistPlot):
    """
    Kernel density estimate plot.

    Each column is passed through ``scipy.stats.gaussian_kde`` and the
    evaluated density is drawn with ``MPLPlot._plot``.
    """

    @property
    def _kind(self) -> Literal["kde"]:
        return "kde"

    @property
    def orientation(self) -> Literal["vertical"]:
        return "vertical"

    def __init__(
        self, data, bw_method=None, ind=None, *, weights=None, **kwargs
    ) -> None:
        """
        Parameters
        ----------
        data : object
            Forwarded to ``MPLPlot.__init__``.
        bw_method : optional
            Passed as ``bw_method`` to ``gaussian_kde``.
        ind : None, int or array-like, optional
            Evaluation points, or how many to generate; see ``_get_ind``.
        weights : optional
            Stored on the instance.
        **kwargs
            Forwarded to ``MPLPlot.__init__``.
        """
        # Do not call LinePlot.__init__ which may fill nan
        MPLPlot.__init__(self, data, **kwargs)  # pylint: disable=non-parent-init-called
        self.bw_method = bw_method
        self.ind = ind
        self.weights = weights

    @staticmethod
    def _get_ind(y: np.ndarray, ind):
        """
        Return the points at which the density is evaluated.

        With ``ind=None`` 1000 evenly spaced points are generated, with an
        integer ``ind`` that many; both span the data range widened by half
        the range on each side. Any other ``ind`` is returned unchanged.
        """
        if ind is None or is_integer(ind):
            n_points = 1000 if ind is None else ind
            # np.nanmax() and np.nanmin() ignores the missing values
            upper = np.nanmax(y)
            lower = np.nanmin(y)
            sample_range = upper - lower
            ind = np.linspace(
                lower - 0.5 * sample_range,
                upper + 0.5 * sample_range,
                n_points,
            )
        return ind

    @classmethod
    # error: Signature of "_plot" incompatible with supertype "MPLPlot"
    def _plot(  # type: ignore[override]
        cls,
        ax: Axes,
        y: np.ndarray,
        style=None,
        bw_method=None,
        ind=None,
        column_num=None,
        stacking_id: int | None = None,
        **kwds,
    ):
        """Fit a Gaussian KDE to the non-NA values of ``y`` and draw it at ``ind``."""
        from scipy.stats import gaussian_kde

        observed = remove_na_arraylike(y)
        estimator = gaussian_kde(observed, bw_method=bw_method)
        density = estimator.evaluate(ind)
        return MPLPlot._plot(ax, ind, density, style=style, **kwds)

    def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None:
        """Add the KDE bandwidth method and evaluation points to ``kwds``."""
        kwds.update(
            bw_method=self.bw_method,
            ind=type(self)._get_ind(y, ind=self.ind),
        )

    def _post_plot_logic(self, ax: Axes, data) -> None:
        """Label the y axis "Density"."""
        ax.set_ylabel("Density")
def _grouped_plot(
    plotf,
    data: Series | DataFrame,
    column=None,
    by=None,
    numeric_only: bool = True,
    figsize: tuple[float, float] | None = None,
    sharex: bool = True,
    sharey: bool = True,
    layout=None,
    rot: float = 0,
    ax=None,
    **kwargs,
):
    """
    Group ``data`` by ``by`` and call ``plotf`` once per group on its own axes.

    Parameters
    ----------
    plotf : callable
        Called as ``plotf(group, ax, **kwargs)`` for every group.
    data : Series or DataFrame
        Object to group.
    column : object, optional
        If given, the grouped object is subset with ``grouped[column]``.
    by : object, optional
        Passed to ``data.groupby``.
    numeric_only : bool, default True
        If True, DataFrame groups are reduced to their numeric data first.
    figsize, sharex, sharey, layout, ax
        Forwarded to ``create_subplots``.
    rot : float, default 0
        Not referenced in this function body.
    **kwargs
        Forwarded to ``plotf``.

    Returns
    -------
    tuple
        ``(fig, axes)`` as returned by ``create_subplots``.

    Raises
    ------
    ValueError
        If ``figsize == "default"``.
    """
    # error: Non-overlapping equality check (left operand type: "Optional[Tuple[float,
    # float]]", right operand type: "Literal['default']")
    if figsize == "default":  # type: ignore[comparison-overlap]
        # allowed to specify mpl default with 'default'
        raise ValueError(
            "figsize='default' is no longer supported. "
            "Specify figure size by tuple instead"
        )

    grouped = data.groupby(by)
    if column is not None:
        grouped = grouped[column]

    fig, axes = create_subplots(
        naxes=len(grouped),
        figsize=figsize,
        sharex=sharex,
        sharey=sharey,
        ax=ax,
        layout=layout,
    )
    flat_axes = flatten_axes(axes)

    for position, (key, group) in enumerate(grouped):
        group_ax = flat_axes[position]
        if numeric_only and isinstance(group, ABCDataFrame):
            group = group._get_numeric_data()
        plotf(group, group_ax, **kwargs)
        group_ax.set_title(pprint_thing(key))

    return fig, axes
def _grouped_hist(
    data: Series | DataFrame,
    column=None,
    by=None,
    ax=None,
    bins: int = 50,
    figsize: tuple[float, float] | None = None,
    layout=None,
    sharex: bool = False,
    sharey: bool = False,
    rot: float = 90,
    grid: bool = True,
    xlabelsize: int | None = None,
    xrot=None,
    ylabelsize: int | None = None,
    yrot=None,
    legend: bool = False,
    **kwargs,
):
    """
    Draw one histogram per group of ``data``.

    Parameters
    ----------
    data : Series/DataFrame
    column : object, optional
    by : object, optional
    ax : axes, optional
    bins : int, default 50
    figsize : tuple, optional
    layout : optional
    sharex : bool, default False
    sharey : bool, default False
    rot : float, default 90
        Used as ``xrot`` when ``xrot`` is None.
    grid : bool, default True
    xlabelsize, xrot, ylabelsize, yrot : optional
        Forwarded to ``set_ticks_props``.
    legend : bool, default False
        If True, a label is derived from ``data``/``column`` and a legend is
        added to every group's axes.
    kwargs : dict, keyword arguments passed to matplotlib.Axes.hist

    Returns
    -------
    collection of Matplotlib Axes
    """
    # NOTE(review): ``grid`` is not referenced anywhere in this function body.
    if legend:
        assert "label" not in kwargs
        if data.ndim == 1:
            group_label = data.name
        elif column is None:
            group_label = data.columns
        else:
            group_label = column
        kwargs["label"] = group_label

    def plot_group(group, ax) -> None:
        ax.hist(group.dropna().values, bins=bins, **kwargs)
        if legend:
            ax.legend()

    xrot = rot if xrot is None else xrot

    fig, axes = _grouped_plot(
        plot_group,
        data,
        column=column,
        by=by,
        sharex=sharex,
        sharey=sharey,
        ax=ax,
        figsize=figsize,
        layout=layout,
        rot=rot,
    )

    set_ticks_props(
        axes,
        xlabelsize=xlabelsize,
        xrot=xrot,
        ylabelsize=ylabelsize,
        yrot=yrot,
    )
    maybe_adjust_figure(
        fig,
        bottom=0.15,
        top=0.9,
        left=0.1,
        right=0.9,
        hspace=0.5,
        wspace=0.3,
    )
    return axes
def hist_series(
    self: Series,
    by=None,
    ax=None,
    grid: bool = True,
    xlabelsize: int | None = None,
    xrot=None,
    ylabelsize: int | None = None,
    yrot=None,
    figsize: tuple[float, float] | None = None,
    bins: int = 10,
    legend: bool = False,
    **kwds,
):
    """
    Draw a histogram of a Series, optionally one histogram per group.

    Parameters
    ----------
    self : Series
        Values to plot; NA values are dropped before plotting.
    by : object, optional
        If given, plotting is delegated to ``_grouped_hist`` with this grouper.
    ax : matplotlib Axes, optional
        Axes to draw on. When ``by`` is None and ``ax`` is None, the current
        axes of the target figure are used.
    grid : bool, default True
        Passed to ``ax.grid`` when ``by`` is None; forwarded to
        ``_grouped_hist`` otherwise.
    xlabelsize, xrot, ylabelsize, yrot : optional
        Tick label properties forwarded to ``set_ticks_props``.
    figsize : tuple of float, optional
        Figure size in inches.
    bins : int, default 10
        Forwarded to ``Axes.hist``.
    legend : bool, default False
        If True, label the histogram with the Series name and show a legend.
    **kwds
        Forwarded to ``Axes.hist``. A ``figure`` entry (only allowed when
        ``by`` is None) selects the figure to draw on.

    Returns
    -------
    matplotlib Axes or collection of Axes
        A one-dimensional, length-one result is unwrapped to its only element.

    Raises
    ------
    ValueError
        If ``legend`` is combined with ``label``, if ``layout`` is given
        without ``by``, or if ``figure`` is given together with ``by``.
    AssertionError
        If ``ax`` does not belong to the figure being used.
    """
    if legend and "label" in kwds:
        raise ValueError("Cannot use both legend and label")

    if by is None:
        if kwds.get("layout", None) is not None:
            raise ValueError("The 'layout' keyword is not supported when 'by' is None")

        # pyplot is only needed on this branch, so import it here.
        import matplotlib.pyplot as plt

        # hack until the plotting interface is a bit more unified
        # Resolve the figure lazily: building the fallback as a ``pop``
        # default would run ``plt.figure()`` even when the caller supplied
        # ``figure=...``, leaving a stray empty pyplot figure behind.
        if "figure" in kwds:
            fig = kwds.pop("figure")
        elif plt.get_fignums():
            fig = plt.gcf()
        else:
            fig = plt.figure(figsize=figsize)

        if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()):
            fig.set_size_inches(*figsize, forward=True)

        if ax is None:
            ax = fig.gca()
        elif ax.get_figure() != fig:
            raise AssertionError("passed axis not bound to passed figure")

        values = self.dropna().values
        if legend:
            kwds["label"] = self.name
        ax.hist(values, bins=bins, **kwds)
        if legend:
            ax.legend()
        ax.grid(grid)
        axes = np.array([ax])

        # error: Argument 1 to "set_ticks_props" has incompatible type "ndarray[Any,
        # dtype[Any]]"; expected "Axes | Sequence[Axes]"
        set_ticks_props(
            axes,  # type: ignore[arg-type]
            xlabelsize=xlabelsize,
            xrot=xrot,
            ylabelsize=ylabelsize,
            yrot=yrot,
        )
    else:
        if "figure" in kwds:
            raise ValueError(
                "Cannot pass 'figure' when using the "
                "'by' argument, since a new 'Figure' instance will be created"
            )
        axes = _grouped_hist(
            self,
            by=by,
            ax=ax,
            grid=grid,
            figsize=figsize,
            bins=bins,
            xlabelsize=xlabelsize,
            xrot=xrot,
            ylabelsize=ylabelsize,
            yrot=yrot,
            legend=legend,
            **kwds,
        )

    # Unwrap a single-element 1-D array so callers get the Axes itself.
    if hasattr(axes, "ndim"):
        if axes.ndim == 1 and len(axes) == 1:
            return axes[0]
    return axes
def hist_frame(
    data: DataFrame,
    column=None,
    by=None,
    grid: bool = True,
    xlabelsize: int | None = None,
    xrot=None,
    ylabelsize: int | None = None,
    yrot=None,
    ax=None,
    sharex: bool = False,
    sharey: bool = False,
    figsize: tuple[float, float] | None = None,
    layout=None,
    bins: int = 10,
    legend: bool = False,
    **kwds,
):
    """
    Draw one histogram per column of a DataFrame, or per group when ``by`` is set.

    Parameters
    ----------
    data : DataFrame
        Frame holding the values to plot.
    column : object, optional
        Column or columns to restrict ``data`` to; a scalar is wrapped in a list.
    by : object, optional
        If given, plotting is delegated to ``_grouped_hist``.
    grid : bool, default True
        Passed to ``ax.grid`` for every column's axes.
    xlabelsize, xrot, ylabelsize, yrot : optional
        Tick label properties forwarded to ``set_ticks_props``.
    ax, sharex, sharey, figsize, layout
        Forwarded to ``create_subplots``.
    bins : int, default 10
        Forwarded to ``Axes.hist``.
    legend : bool, default False
        If True, label each histogram with its column name and show a legend.
    **kwds
        Forwarded to ``Axes.hist``.

    Returns
    -------
    Axes collection as returned by ``create_subplots`` or ``_grouped_hist``.

    Raises
    ------
    ValueError
        If ``legend`` is combined with ``label``, or if no numeric or
        datetime column is left to plot.
    """
    if legend and "label" in kwds:
        raise ValueError("Cannot use both legend and label")

    if by is not None:
        return _grouped_hist(
            data,
            column=column,
            by=by,
            ax=ax,
            grid=grid,
            figsize=figsize,
            sharex=sharex,
            sharey=sharey,
            layout=layout,
            bins=bins,
            xlabelsize=xlabelsize,
            xrot=xrot,
            ylabelsize=ylabelsize,
            yrot=yrot,
            legend=legend,
            **kwds,
        )

    if column is not None:
        if not isinstance(column, (list, np.ndarray, ABCIndex)):
            column = [column]
        data = data[column]

    # GH32590
    plottable_kinds = (np.number, "datetime64", "datetimetz")
    data = data.select_dtypes(include=plottable_kinds, exclude="timedelta")

    n_columns = len(data.columns)
    if n_columns == 0:
        raise ValueError(
            "hist method requires numerical or datetime columns, nothing to plot."
        )

    fig, axes = create_subplots(
        naxes=n_columns,
        ax=ax,
        squeeze=False,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
        layout=layout,
    )
    flat_axes = flatten_axes(axes)

    # Decided once: only fill in a label when the caller supplied none.
    label_from_column = legend and "label" not in kwds

    for position, col in enumerate(data.columns):
        col_ax = flat_axes[position]
        if label_from_column:
            kwds["label"] = col
        col_ax.hist(data[col].dropna().values, bins=bins, **kwds)
        col_ax.set_title(col)
        col_ax.grid(grid)
        if legend:
            col_ax.legend()

    set_ticks_props(
        axes,
        xlabelsize=xlabelsize,
        xrot=xrot,
        ylabelsize=ylabelsize,
        yrot=yrot,
    )
    maybe_adjust_figure(fig, wspace=0.3, hspace=0.3)

    return axes