You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
582 lines
16 KiB
582 lines
16 KiB
from __future__ import annotations
|
|
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Literal,
|
|
final,
|
|
)
|
|
|
|
import numpy as np
|
|
|
|
from pandas.core.dtypes.common import (
|
|
is_integer,
|
|
is_list_like,
|
|
)
|
|
from pandas.core.dtypes.generic import (
|
|
ABCDataFrame,
|
|
ABCIndex,
|
|
)
|
|
from pandas.core.dtypes.missing import (
|
|
isna,
|
|
remove_na_arraylike,
|
|
)
|
|
|
|
from pandas.io.formats.printing import pprint_thing
|
|
from pandas.plotting._matplotlib.core import (
|
|
LinePlot,
|
|
MPLPlot,
|
|
)
|
|
from pandas.plotting._matplotlib.groupby import (
|
|
create_iter_data_given_by,
|
|
reformat_hist_y_given_by,
|
|
)
|
|
from pandas.plotting._matplotlib.misc import unpack_single_str_list
|
|
from pandas.plotting._matplotlib.tools import (
|
|
create_subplots,
|
|
flatten_axes,
|
|
maybe_adjust_figure,
|
|
set_ticks_props,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from matplotlib.axes import Axes
|
|
from matplotlib.figure import Figure
|
|
|
|
from pandas._typing import PlottingOrientation
|
|
|
|
from pandas import (
|
|
DataFrame,
|
|
Series,
|
|
)
|
|
|
|
|
|
class HistPlot(LinePlot):
    """
    Histogram plot for the matplotlib backend.

    When ``bins`` is an integer, the bin edges are computed once, at
    construction time, from the numeric non-NaN data. Without ``by`` one set
    of edges is shared by every column; with ``by`` one set is computed per
    group.
    """

    @property
    def _kind(self) -> Literal["hist", "kde"]:
        return "hist"

    def __init__(
        self,
        data,
        bins: int | np.ndarray | list[np.ndarray] = 10,
        bottom: int | np.ndarray = 0,
        *,
        range=None,
        weights=None,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        data : Series or DataFrame
            Data to plot; forwarded to ``MPLPlot.__init__``.
        bins : int, ndarray or list of ndarray, default 10
            Number of bins or explicit bin edges. Integers are turned into
            edges by ``_adjust_bins``.
        bottom : int or array-like, default 0
            Baseline of the bars; list-likes are converted to an ndarray.
        range : tuple, optional
            Forwarded as ``range`` to ``np.histogram`` when computing edges.
        weights : array-like, optional
            Per-sample weights; see ``_get_column_weights``.
        **kwargs
            Remaining plot options, forwarded to ``MPLPlot.__init__``.
        """
        if is_list_like(bottom):
            bottom = np.array(bottom)
        self.bottom = bottom

        self._bin_range = range
        self.weights = weights

        self.xlabel = kwargs.get("xlabel")
        self.ylabel = kwargs.get("ylabel")
        # Do not call LinePlot.__init__ which may fill nan
        MPLPlot.__init__(self, data, **kwargs)  # pylint: disable=non-parent-init-called

        self.bins = self._adjust_bins(bins)

    def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
        """
        Turn an integer bin count into bin edges.

        Without ``by`` a single edge array is returned. With ``by`` a list
        holding one edge array per group is returned. Non-integer ``bins``
        are returned unchanged.
        """
        if is_integer(bins):
            if self.by is not None:
                by_modified = unpack_single_str_list(self.by)
                grouped = self.data.groupby(by_modified)[self.columns]
                bins = [self._calculate_bins(group, bins) for _, group in grouped]
            else:
                bins = self._calculate_bins(self.data, bins)
        return bins

    def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
        """Calculate bin edges from the numeric, non-NaN values of ``data``."""
        nd_values = data.infer_objects(copy=False)._get_numeric_data()
        values = np.ravel(nd_values)
        values = values[~isna(values)]

        # Only the edges are needed; the counts are discarded.
        _, bin_edges = np.histogram(values, bins=bins, range=self._bin_range)
        return bin_edges

    # error: Signature of "_plot" incompatible with supertype "LinePlot"
    @classmethod
    def _plot(  # type: ignore[override]
        cls,
        ax: Axes,
        y: np.ndarray,
        style=None,
        bottom: int | np.ndarray = 0,
        column_num: int = 0,
        stacking_id=None,
        *,
        bins,
        **kwds,
    ):
        """
        Draw one histogram with ``ax.hist`` and return its patches.

        ``bottom`` is offset by the values already recorded by the stacker
        for ``stacking_id``. The counts are then recorded via
        ``_update_stacker``. ``style`` is accepted but ignored.
        """
        if column_num == 0:
            cls._initialize_stacker(ax, stacking_id, len(bins) - 1)

        base = np.zeros(len(bins) - 1)
        bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"])
        # ignore style
        n, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds)
        cls._update_stacker(ax, stacking_id, n)
        return patches

    def _make_plot(self, fig: Figure) -> None:
        """Draw one histogram per column (or per group when ``by`` is set)."""
        colors = self._get_colors()
        stacking_id = self._get_stacking_id()

        # Re-create iterated data if `by` is assigned by users
        data = (
            create_iter_data_given_by(self.data, self._kind)
            if self.by is not None
            else self.data
        )

        # error: Argument "data" to "_iter_data" of "MPLPlot" has incompatible
        # type "object"; expected "DataFrame | dict[Hashable, Series | DataFrame]"
        for i, (label, y) in enumerate(self._iter_data(data=data)):  # type: ignore[arg-type]
            ax = self._get_ax(i)

            kwds = self.kwds.copy()
            if self.color is not None:
                kwds["color"] = self.color

            label = pprint_thing(label)
            label = self._mark_right_label(label, index=i)
            kwds["label"] = label

            style, kwds = self._apply_style_colors(colors, kwds, i, label)
            if style is not None:
                kwds["style"] = style

            self._make_plot_keywords(kwds, y)

            # the bins is multi-dimension array now and each plot need only 1-d and
            # when by is applied, label should be columns that are grouped
            if self.by is not None:
                kwds["bins"] = kwds["bins"][i]
                kwds["label"] = self.columns
                kwds.pop("color")

            if self.weights is not None:
                kwds["weights"] = type(self)._get_column_weights(self.weights, i, y)

            y = reformat_hist_y_given_by(y, self.by)

            artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)

            # when by is applied, show title for subplots to know which group it is
            if self.by is not None:
                ax.set_title(pprint_thing(label))

            self._append_legend_handles_labels(artists[0], label)

    def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None:
        """
        Add the histogram-specific ``bottom`` and ``bins`` entries to ``kwds``.

        ``kwds`` is modified in place. ``y`` is unused here; the ``KdePlot``
        override needs it.
        """
        kwds["bottom"] = self.bottom
        kwds["bins"] = self.bins

    @final
    @staticmethod
    def _get_column_weights(weights, i: int, y):
        """
        Return the weights that apply to the ``i``-th plotted column.

        Weights may be multi-dimensional, e.g. a (10, 2) array. In that case
        column ``i`` is used for the ``i``-th plot. 1-D (or single-column)
        weights are reused for every column. Entries where ``y`` is NaN are
        dropped, so the weights stay aligned with the data.

        Raises
        ------
        ValueError
            If multi-dimensional weights have no column ``i``.
        """
        if weights is not None:
            # Accept any array-like (e.g. a plain list or a DataFrame). The
            # positional column slicing and boolean masking below need an
            # ndarray; a list would raise TypeError instead.
            weights = np.asarray(weights)
            if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
                try:
                    weights = weights[:, i]
                except IndexError as err:
                    raise ValueError(
                        "weights must have the same shape as data, "
                        "or be a single column"
                    ) from err
            weights = weights[~isna(y)]
        return weights

    def _post_plot_logic(self, ax: Axes, data) -> None:
        """Label the frequency axis "Frequency" unless a label was given."""
        if self.orientation == "horizontal":
            # error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible
            # type "Hashable"; expected "str"
            ax.set_xlabel(
                "Frequency"
                if self.xlabel is None
                else self.xlabel  # type: ignore[arg-type]
            )
            ax.set_ylabel(self.ylabel)  # type: ignore[arg-type]
        else:
            ax.set_xlabel(self.xlabel)  # type: ignore[arg-type]
            ax.set_ylabel(
                "Frequency"
                if self.ylabel is None
                else self.ylabel  # type: ignore[arg-type]
            )

    @property
    def orientation(self) -> PlottingOrientation:
        """``"horizontal"`` if requested via the ``orientation`` kwd, else ``"vertical"``."""
        if self.kwds.get("orientation", None) == "horizontal":
            return "horizontal"
        else:
            return "vertical"
|
|
|
|
|
|
class KdePlot(HistPlot):
    """
    Kernel density estimate plot built on ``scipy.stats.gaussian_kde``.

    Reuses the ``HistPlot`` drawing loop, but draws a density line instead
    of bars.
    """

    @property
    def _kind(self) -> Literal["kde"]:
        return "kde"

    @property
    def orientation(self) -> Literal["vertical"]:
        # The density curve is always drawn in vertical orientation.
        return "vertical"

    def __init__(
        self, data, bw_method=None, ind=None, *, weights=None, **kwargs
    ) -> None:
        """
        Parameters
        ----------
        data : Series or DataFrame
            Data to plot; forwarded to ``MPLPlot.__init__``.
        bw_method : optional
            Forwarded to ``gaussian_kde`` as ``bw_method``.
        ind : None, int or array-like, optional
            Evaluation points, or how many of them to generate
            (see ``_get_ind``).
        weights : optional
            Stored on the instance like ``HistPlot.weights``.
        """
        # Do not call LinePlot.__init__ which may fill nan
        MPLPlot.__init__(self, data, **kwargs)  # pylint: disable=non-parent-init-called
        self.bw_method = bw_method
        self.ind = ind
        self.weights = weights

    @staticmethod
    def _get_ind(y: np.ndarray, ind):
        """
        Return the points at which the density is evaluated.

        For ``None`` (1000 points) or an integer (that many points), the
        points are evenly spaced from ``min(y) - range/2`` to
        ``max(y) + range/2``, ignoring NaNs. Any other value is returned
        unchanged.
        """
        if ind is None:
            num_points = 1000
        elif is_integer(ind):
            num_points = ind
        else:
            # Caller supplied explicit evaluation points.
            return ind

        # np.nanmax() and np.nanmin() ignores the missing values
        sample_range = np.nanmax(y) - np.nanmin(y)
        pad = 0.5 * sample_range
        return np.linspace(
            np.nanmin(y) - pad,
            np.nanmax(y) + pad,
            num_points,
        )

    # error: Signature of "_plot" incompatible with supertype "MPLPlot"
    @classmethod
    def _plot(  # type: ignore[override]
        cls,
        ax: Axes,
        y: np.ndarray,
        style=None,
        bw_method=None,
        ind=None,
        column_num=None,
        stacking_id: int | None = None,
        **kwds,
    ):
        """
        Fit a Gaussian KDE to the non-missing values of ``y``.

        The KDE is evaluated at ``ind`` and the result is drawn as a line via
        ``MPLPlot._plot``.
        """
        from scipy.stats import gaussian_kde

        observed = remove_na_arraylike(y)
        estimator = gaussian_kde(observed, bw_method=bw_method)
        density = estimator.evaluate(ind)
        return MPLPlot._plot(ax, ind, density, style=style, **kwds)

    def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None:
        """Add the ``bw_method`` and ``ind`` entries to ``kwds`` in place."""
        kwds["bw_method"] = self.bw_method
        kwds["ind"] = type(self)._get_ind(y, ind=self.ind)

    def _post_plot_logic(self, ax: Axes, data) -> None:
        """Label the y axis "Density"."""
        ax.set_ylabel("Density")
|
|
|
|
|
|
def _grouped_plot(
|
|
plotf,
|
|
data: Series | DataFrame,
|
|
column=None,
|
|
by=None,
|
|
numeric_only: bool = True,
|
|
figsize: tuple[float, float] | None = None,
|
|
sharex: bool = True,
|
|
sharey: bool = True,
|
|
layout=None,
|
|
rot: float = 0,
|
|
ax=None,
|
|
**kwargs,
|
|
):
|
|
# error: Non-overlapping equality check (left operand type: "Optional[Tuple[float,
|
|
# float]]", right operand type: "Literal['default']")
|
|
if figsize == "default": # type: ignore[comparison-overlap]
|
|
# allowed to specify mpl default with 'default'
|
|
raise ValueError(
|
|
"figsize='default' is no longer supported. "
|
|
"Specify figure size by tuple instead"
|
|
)
|
|
|
|
grouped = data.groupby(by)
|
|
if column is not None:
|
|
grouped = grouped[column]
|
|
|
|
naxes = len(grouped)
|
|
fig, axes = create_subplots(
|
|
naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
|
|
)
|
|
|
|
_axes = flatten_axes(axes)
|
|
|
|
for i, (key, group) in enumerate(grouped):
|
|
ax = _axes[i]
|
|
if numeric_only and isinstance(group, ABCDataFrame):
|
|
group = group._get_numeric_data()
|
|
plotf(group, ax, **kwargs)
|
|
ax.set_title(pprint_thing(key))
|
|
|
|
return fig, axes
|
|
|
|
|
|
def _grouped_hist(
    data: Series | DataFrame,
    column=None,
    by=None,
    ax=None,
    bins: int = 50,
    figsize: tuple[float, float] | None = None,
    layout=None,
    sharex: bool = False,
    sharey: bool = False,
    rot: float = 90,
    grid: bool = True,
    xlabelsize: int | None = None,
    xrot=None,
    ylabelsize: int | None = None,
    yrot=None,
    legend: bool = False,
    **kwargs,
):
    """
    Grouped histogram

    Parameters
    ----------
    data : Series/DataFrame
    column : object, optional
    by : object, optional
    ax : axes, optional
    bins : int, default 50
    figsize : tuple, optional
    layout : optional
    sharex : bool, default False
    sharey : bool, default False
    rot : float, default 90
        Used as the x tick label rotation when ``xrot`` is None.
    grid : bool, default True
        Whether to show axis grid lines on each subplot.
    xlabelsize : int, optional
    xrot : float, optional
    ylabelsize : int, optional
    yrot : float, optional
    legend : bool, default False
        Label each histogram (Series name, DataFrame columns or ``column``)
        and draw a legend on each subplot.
    kwargs : dict, keyword arguments passed to matplotlib.Axes.hist

    Returns
    -------
    collection of Matplotlib Axes
    """
    if legend:
        # Public callers reject legend+label with a ValueError before
        # reaching here; this only guards the internal invariant.
        assert "label" not in kwargs
        if data.ndim == 1:
            kwargs["label"] = data.name
        elif column is None:
            kwargs["label"] = data.columns
        else:
            kwargs["label"] = column

    def plot_group(group, ax) -> None:
        ax.hist(group.dropna().values, bins=bins, **kwargs)
        # Honor ``grid`` like the non-grouped hist_series/hist_frame paths;
        # it used to be accepted here but silently ignored.
        ax.grid(grid)
        if legend:
            ax.legend()

    if xrot is None:
        xrot = rot

    fig, axes = _grouped_plot(
        plot_group,
        data,
        column=column,
        by=by,
        sharex=sharex,
        sharey=sharey,
        ax=ax,
        figsize=figsize,
        layout=layout,
        rot=rot,
    )

    set_ticks_props(
        axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
    )

    maybe_adjust_figure(
        fig, bottom=0.15, top=0.9, left=0.1, right=0.9, hspace=0.5, wspace=0.3
    )
    return axes
|
|
|
|
|
|
def hist_series(
    self: Series,
    by=None,
    ax=None,
    grid: bool = True,
    xlabelsize: int | None = None,
    xrot=None,
    ylabelsize: int | None = None,
    yrot=None,
    figsize: tuple[float, float] | None = None,
    bins: int = 10,
    legend: bool = False,
    **kwds,
):
    """
    Draw a histogram of a Series with matplotlib.

    Parameters
    ----------
    self : Series
    by : object, optional
        If given, delegate to ``_grouped_hist`` and draw one subplot per group.
    ax : matplotlib Axes, optional
        Axes to draw on when ``by`` is None. It must belong to the target
        figure.
    grid : bool, default True
        Whether to show axis grid lines (only applied when ``by`` is None).
    xlabelsize, ylabelsize : int, optional
    xrot, yrot : float, optional
    figsize : tuple, optional
    bins : int, default 10
    legend : bool, default False
        Label the histogram with the Series name and draw a legend.
    **kwds
        Passed to ``Axes.hist``. Two keys are special:

        - ``figure`` selects the target Figure and is only allowed without
          ``by``.
        - ``layout`` is only allowed together with ``by``.

    Returns
    -------
    matplotlib Axes, or the axes array from ``_grouped_hist`` (unwrapped to
    a single Axes when it is 1-D with one element).

    Raises
    ------
    ValueError
        If ``legend`` and ``label`` are both given, if ``layout`` is given
        without ``by``, or if ``figure`` is given together with ``by``.
    AssertionError
        If ``ax`` is not bound to the target figure.
    """
    import matplotlib.pyplot as plt

    if legend and "label" in kwds:
        raise ValueError("Cannot use both legend and label")

    if by is None:
        if kwds.get("layout", None) is not None:
            raise ValueError("The 'layout' keyword is not supported when 'by' is None")
        # hack until the plotting interface is a bit more unified
        fig = kwds.pop("figure", None)
        if fig is None:
            # Resolve the default lazily. Passing it as the ``pop`` default
            # evaluated it eagerly, which created (and leaked) an empty
            # figure even when the caller supplied ``figure=``.
            fig = plt.gcf() if plt.get_fignums() else plt.figure(figsize=figsize)
        if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()):
            fig.set_size_inches(*figsize, forward=True)
        if ax is None:
            ax = fig.gca()
        elif ax.get_figure() != fig:
            raise AssertionError("passed axis not bound to passed figure")
        values = self.dropna().values
        if legend:
            kwds["label"] = self.name
        ax.hist(values, bins=bins, **kwds)
        if legend:
            ax.legend()
        ax.grid(grid)
        axes = np.array([ax])

        # error: Argument 1 to "set_ticks_props" has incompatible type "ndarray[Any,
        # dtype[Any]]"; expected "Axes | Sequence[Axes]"
        set_ticks_props(
            axes,  # type: ignore[arg-type]
            xlabelsize=xlabelsize,
            xrot=xrot,
            ylabelsize=ylabelsize,
            yrot=yrot,
        )

    else:
        if "figure" in kwds:
            raise ValueError(
                "Cannot pass 'figure' when using the "
                "'by' argument, since a new 'Figure' instance will be created"
            )
        axes = _grouped_hist(
            self,
            by=by,
            ax=ax,
            grid=grid,
            figsize=figsize,
            bins=bins,
            xlabelsize=xlabelsize,
            xrot=xrot,
            ylabelsize=ylabelsize,
            yrot=yrot,
            legend=legend,
            **kwds,
        )

    # Unwrap a single-Axes 1-D array so callers get a plain Axes.
    if hasattr(axes, "ndim"):
        if axes.ndim == 1 and len(axes) == 1:
            return axes[0]
    return axes
|
|
|
|
|
|
def hist_frame(
    data: DataFrame,
    column=None,
    by=None,
    grid: bool = True,
    xlabelsize: int | None = None,
    xrot=None,
    ylabelsize: int | None = None,
    yrot=None,
    ax=None,
    sharex: bool = False,
    sharey: bool = False,
    figsize: tuple[float, float] | None = None,
    layout=None,
    bins: int = 10,
    legend: bool = False,
    **kwds,
):
    """
    Draw one histogram per numeric or datetime column of a DataFrame.

    With ``by`` the work is delegated to ``_grouped_hist``. Otherwise:

    - the optional ``column`` selection is applied;
    - only number/datetime/datetimetz columns (excluding timedelta) are kept;
    - each column is drawn on its own subplot, titled with the column label.

    Returns
    -------
    The axes created by ``create_subplots`` (or by ``_grouped_hist``).

    Raises
    ------
    ValueError
        If ``legend`` and ``label`` are both given, or if no plottable
        columns remain after dtype selection.
    """
    if legend and "label" in kwds:
        raise ValueError("Cannot use both legend and label")

    if by is not None:
        return _grouped_hist(
            data,
            column=column,
            by=by,
            ax=ax,
            grid=grid,
            figsize=figsize,
            sharex=sharex,
            sharey=sharey,
            layout=layout,
            bins=bins,
            xlabelsize=xlabelsize,
            xrot=xrot,
            ylabelsize=ylabelsize,
            yrot=yrot,
            legend=legend,
            **kwds,
        )

    if column is not None:
        # A scalar label is wrapped so that data[...] always yields a DataFrame.
        if isinstance(column, (list, np.ndarray, ABCIndex)):
            wanted = column
        else:
            wanted = [column]
        data = data[wanted]
    # GH32590
    data = data.select_dtypes(
        include=(np.number, "datetime64", "datetimetz"), exclude="timedelta"
    )
    n_plots = len(data.columns)

    if not n_plots:
        raise ValueError(
            "hist method requires numerical or datetime columns, nothing to plot."
        )

    fig, axes = create_subplots(
        naxes=n_plots,
        ax=ax,
        squeeze=False,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
        layout=layout,
    )
    flat_axes = flatten_axes(axes)

    # Decided once, before the loop starts writing "label" into kwds.
    label_each_column = legend and "label" not in kwds

    for position, col_name in enumerate(data.columns):
        target_ax = flat_axes[position]
        if label_each_column:
            kwds["label"] = col_name
        target_ax.hist(data[col_name].dropna().values, bins=bins, **kwds)
        target_ax.set_title(col_name)
        target_ax.grid(grid)
        if legend:
            target_ax.legend()

    set_ticks_props(
        axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
    )
    maybe_adjust_figure(fig, wspace=0.3, hspace=0.3)

    return axes
|