You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

639 lines
17 KiB

6 months ago
from __future__ import annotations
from decimal import Decimal
import operator
import os
from sys import byteorder
from typing import (
TYPE_CHECKING,
Callable,
ContextManager,
cast,
)
import warnings
import numpy as np
from pandas._config.localization import (
can_set_locale,
get_locales,
set_locale,
)
from pandas.compat import pa_version_under10p1
from pandas.core.dtypes.common import is_string_dtype
import pandas as pd
from pandas import (
ArrowDtype,
DataFrame,
Index,
MultiIndex,
RangeIndex,
Series,
)
from pandas._testing._io import (
round_trip_localpath,
round_trip_pathlib,
round_trip_pickle,
write_to_compressed,
)
from pandas._testing._warnings import (
assert_produces_warning,
maybe_produces_warning,
)
from pandas._testing.asserters import (
assert_almost_equal,
assert_attr_equal,
assert_categorical_equal,
assert_class_equal,
assert_contains_all,
assert_copy,
assert_datetime_array_equal,
assert_dict_equal,
assert_equal,
assert_extension_array_equal,
assert_frame_equal,
assert_index_equal,
assert_indexing_slices_equivalent,
assert_interval_array_equal,
assert_is_sorted,
assert_is_valid_plot_return_object,
assert_metadata_equivalent,
assert_numpy_array_equal,
assert_period_array_equal,
assert_series_equal,
assert_sp_array_equal,
assert_timedelta_array_equal,
raise_assert_detail,
)
from pandas._testing.compat import (
get_dtype,
get_obj,
)
from pandas._testing.contexts import (
assert_cow_warning,
decompress_file,
ensure_clean,
raises_chained_assignment_error,
set_timezone,
use_numexpr,
with_csv_dialect,
)
from pandas.core.arrays import (
BaseMaskedArray,
ExtensionArray,
NumpyExtensionArray,
)
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.construction import extract_array
if TYPE_CHECKING:
from pandas._typing import (
Dtype,
NpDtype,
)
from pandas.core.arrays import ArrowExtensionArray
# Dtype specifiers (type objects and dtype strings) grouped by family.
# Leaf lists come first; the composite lists below are built from them.
UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"]
SIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [int, "int8", "int16", "int32", "int64"]
FLOAT_NUMPY_DTYPES: list[NpDtype] = [float, "float32", "float64"]

UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"]
SIGNED_INT_EA_DTYPES: list[Dtype] = ["Int8", "Int16", "Int32", "Int64"]
FLOAT_EA_DTYPES: list[Dtype] = ["Float32", "Float64"]

COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
STRING_DTYPES: list[Dtype] = [str, "str", "U"]
DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"]
TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"]
BOOL_DTYPES: list[Dtype] = [bool, "bool"]
BYTES_DTYPES: list[Dtype] = [bytes, "bytes"]
OBJECT_DTYPES: list[Dtype] = [object, "object"]

# Integer families: unsigned entries precede signed ones.
ALL_INT_NUMPY_DTYPES = [*UNSIGNED_INT_NUMPY_DTYPES, *SIGNED_INT_NUMPY_DTYPES]
ALL_INT_EA_DTYPES = [*UNSIGNED_INT_EA_DTYPES, *SIGNED_INT_EA_DTYPES]
ALL_INT_DTYPES: list[Dtype] = [*ALL_INT_NUMPY_DTYPES, *ALL_INT_EA_DTYPES]

ALL_FLOAT_DTYPES: list[Dtype] = [*FLOAT_NUMPY_DTYPES, *FLOAT_EA_DTYPES]

# Real (non-complex numeric) families: float entries precede integer ones.
ALL_REAL_NUMPY_DTYPES = [*FLOAT_NUMPY_DTYPES, *ALL_INT_NUMPY_DTYPES]
ALL_REAL_EXTENSION_DTYPES = [*FLOAT_EA_DTYPES, *ALL_INT_EA_DTYPES]
ALL_REAL_DTYPES: list[Dtype] = [*ALL_REAL_NUMPY_DTYPES, *ALL_REAL_EXTENSION_DTYPES]
ALL_NUMERIC_DTYPES: list[Dtype] = [*ALL_REAL_DTYPES, *COMPLEX_DTYPES]

ALL_NUMPY_DTYPES = [
    *ALL_REAL_NUMPY_DTYPES,
    *COMPLEX_DTYPES,
    *STRING_DTYPES,
    *DATETIME64_DTYPES,
    *TIMEDELTA64_DTYPES,
    *BOOL_DTYPES,
    *OBJECT_DTYPES,
    *BYTES_DTYPES,
]
# NumPy scalar types narrower than 64 bits.
NARROW_NP_DTYPES = [
    np.float16,
    np.float32,
    np.int8,
    np.int16,
    np.int32,
    np.uint8,
    np.uint16,
    np.uint32,
]

# Built-in Python types.
PYTHON_DATA_TYPES = [
    str,
    int,
    float,
    complex,
    list,
    tuple,
    range,
    dict,
    set,
    frozenset,
    bool,
    bytes,
    bytearray,
    memoryview,
]

# Byte-order prefix character matching the running interpreter's byte order.
ENDIAN = dict(little="<", big=">")[byteorder]

# Assorted missing-value sentinels.
NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")]

# One NaT scalar per (type, unit) pair: all datetime64 units first, then
# all timedelta64 units.
NP_NAT_OBJECTS = [
    nat_type("NaT", resolution)
    for nat_type in (np.datetime64, np.timedelta64)
    for resolution in "Y M W D h m s ms us ns ps fs as".split()
]
if pa_version_under10p1:
    # No usable pyarrow according to ``pa_version_under10p1``: only the
    # names below are defined, each as an empty list.
    FLOAT_PYARROW_DTYPES_STR_REPR = []
    ALL_INT_PYARROW_DTYPES_STR_REPR = []
    ALL_PYARROW_DTYPES = []
    ALL_REAL_PYARROW_DTYPES_STR_REPR = []
else:
    import pyarrow as pa

    UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
    SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
    ALL_INT_PYARROW_DTYPES = [
        *UNSIGNED_INT_PYARROW_DTYPES,
        *SIGNED_INT_PYARROW_DTYPES,
    ]
    # String form of the pandas ArrowDtype wrapping each pyarrow type.
    ALL_INT_PYARROW_DTYPES_STR_REPR = list(
        map(str, map(ArrowDtype, ALL_INT_PYARROW_DTYPES))
    )

    # pa.float16 doesn't seem supported
    # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
    FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
    FLOAT_PYARROW_DTYPES_STR_REPR = list(
        map(str, map(ArrowDtype, FLOAT_PYARROW_DTYPES))
    )
    DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
    STRING_PYARROW_DTYPES = [pa.string()]
    BINARY_PYARROW_DTYPES = [pa.binary()]

    TIME_PYARROW_DTYPES = [
        pa.time32("s"),
        pa.time32("ms"),
        pa.time64("us"),
        pa.time64("ns"),
    ]
    DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()]
    # Every unit paired with every timezone (unit varies slowest).
    DATETIME_PYARROW_DTYPES = [
        pa.timestamp(unit=unit, tz=tz)
        for unit in ["s", "ms", "us", "ns"]
        for tz in [None, "UTC", "US/Pacific", "US/Eastern"]
    ]
    TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]]

    BOOL_PYARROW_DTYPES = [pa.bool_()]

    # TODO: Add container like pyarrow types:
    #  https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions
    ALL_PYARROW_DTYPES = [
        *ALL_INT_PYARROW_DTYPES,
        *FLOAT_PYARROW_DTYPES,
        *DECIMAL_PYARROW_DTYPES,
        *STRING_PYARROW_DTYPES,
        *BINARY_PYARROW_DTYPES,
        *TIME_PYARROW_DTYPES,
        *DATE_PYARROW_DTYPES,
        *DATETIME_PYARROW_DTYPES,
        *TIMEDELTA_PYARROW_DTYPES,
        *BOOL_PYARROW_DTYPES,
    ]
    ALL_REAL_PYARROW_DTYPES_STR_REPR = [
        *ALL_INT_PYARROW_DTYPES_STR_REPR,
        *FLOAT_PYARROW_DTYPES_STR_REPR,
    ]

# NumPy float dtypes, then the float/int extension dtypes, then the
# pyarrow-backed dtype strings (empty when pyarrow is unavailable).
ALL_REAL_NULLABLE_DTYPES = [
    *FLOAT_NUMPY_DTYPES,
    *ALL_REAL_EXTENSION_DTYPES,
    *ALL_REAL_PYARROW_DTYPES_STR_REPR,
]
# Each arithmetic op is listed as its forward dunder immediately followed
# by the reflected ("r"-prefixed) dunder.
arithmetic_dunder_methods = [
    f"__{prefix}{base}__"
    for base in ("add", "sub", "mul", "floordiv", "truediv", "pow", "mod")
    for prefix in ("", "r")
]

comparison_dunder_methods = [
    f"__{base}__" for base in ("eq", "ne", "le", "lt", "ge", "gt")
]
# -----------------------------------------------------------------------------
# Comparators
def box_expected(expected, box_cls, transpose: bool = True):
    """
    Wrap the expected result of a test in the requested container type.

    Parameters
    ----------
    expected : np.ndarray, Index, Series
    box_cls : {Index, Series, DataFrame}
        ``pd.array``, ``np.ndarray`` / ``np.array`` and ``to_array`` are
        handled as well.
    transpose : bool, default True
        Only consulted when ``box_cls`` is DataFrame: the single-column frame
        is transposed and then stacked twice, giving a two-row result.

    Returns
    -------
    subclass of box_cls

    Raises
    ------
    NotImplementedError
        If ``box_cls`` is none of the supported boxes.
    """
    if box_cls is pd.array:
        if isinstance(expected, RangeIndex):
            # pd.array would return an IntegerArray
            return NumpyExtensionArray(np.asarray(expected._values))
        return pd.array(expected, copy=False)

    wants_frame = box_cls is DataFrame
    if box_cls is Index or box_cls is Series or wants_frame:
        # A frame is built by going through a Series first.
        constructor = Index if box_cls is Index else Series
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", "Dtype inference", category=FutureWarning)
            boxed = constructor(expected)
            if wants_frame:
                boxed = boxed.to_frame()
        if wants_frame and transpose:
            # for vector operations, we need a DataFrame to be a single-row,
            #  not a single-column, in order to operate against non-DataFrame
            #  vectors of the same length. But convert to two rows to avoid
            #  single-row special cases in datetime arithmetic
            boxed = pd.concat([boxed.T] * 2, ignore_index=True)
        return boxed

    if box_cls is np.ndarray or box_cls is np.array:
        return np.array(expected)
    if box_cls is to_array:
        return to_array(expected)
    raise NotImplementedError(box_cls)
def to_array(obj):
    """
    Similar to pd.array, but does not cast numpy dtypes to nullable dtypes.

    Objects with a non-None ``dtype`` attribute are unwrapped with
    ``extract_array(..., extract_numpy=True)``; anything else is passed to
    ``np.asarray``.
    """
    # temporary implementation until we get pd.array in place
    if getattr(obj, "dtype", None) is not None:
        return extract_array(obj, extract_numpy=True)
    return np.asarray(obj)
class SubclassedSeries(Series):
    """
    Minimal Series subclass that declares ``testattr`` and ``name`` as
    ``_metadata`` and overrides the constructor properties.
    """

    _metadata = ["testattr", "name"]

    @property
    def _constructor(self):
        # For testing, those properties return a generic callable, and not
        # the actual class. In this case that is equivalent, but it is to
        # ensure we don't rely on the property returning a class
        # See https://github.com/pandas-dev/pandas/pull/46018 and
        # https://github.com/pandas-dev/pandas/issues/32638 and linked issues
        def _build_series(*args, **kwargs):
            return SubclassedSeries(*args, **kwargs)

        return _build_series

    @property
    def _constructor_expanddim(self):
        # Same idea: hand back a plain callable rather than the class.
        def _build_frame(*args, **kwargs):
            return SubclassedDataFrame(*args, **kwargs)

        return _build_frame
class SubclassedDataFrame(DataFrame):
    """
    Minimal DataFrame subclass that declares ``testattr`` as ``_metadata``
    and overrides the constructor properties with plain callables.
    """

    _metadata = ["testattr"]

    @property
    def _constructor(self):
        # Return a generic callable rather than the class itself.
        def _build_frame(*args, **kwargs):
            return SubclassedDataFrame(*args, **kwargs)

        return _build_frame

    @property
    def _constructor_sliced(self):
        # Return a generic callable rather than the class itself.
        def _build_series(*args, **kwargs):
            return SubclassedSeries(*args, **kwargs)

        return _build_series
def convert_rows_list_to_csv_str(rows_list: list[str]) -> str:
    """
    Build the text ``to_csv()`` is expected to produce on the current OS.

    Rows are joined with ``os.linesep`` and a trailing ``os.linesep`` is
    appended.

    Parameters
    ----------
    rows_list : List[str]
        Each element represents the row of csv.

    Returns
    -------
    str
        Expected output of to_csv() in current OS.
    """
    return os.linesep.join(rows_list) + os.linesep
def external_error_raised(expected_exception: type[Exception]) -> ContextManager:
    """
    Mark a ``pytest.raises`` whose error message comes from external code.

    Parameters
    ----------
    expected_exception : Exception
        Expected error to raise.

    Returns
    -------
    ContextManager
        The result of ``pytest.raises(expected_exception, match=None)``, so
        the exception message is not checked.
    """
    # pytest is imported lazily, only when this helper is called.
    from pytest import raises

    return raises(expected_exception, match=None)
cython_table = pd.core.common._cython_table.items()


def get_cython_table_params(ndframe, func_names_and_expected):
    """
    Pair an object with method names, their ``_cython_table`` aliases and
    the expected results.

    Parameters
    ----------
    ndframe : DataFrame or Series
    func_names_and_expected : Sequence of two items
        The first item is a name of a NDFrame method ('sum', 'prod') etc.
        The second item is the expected return value.

    Returns
    -------
    list
        List of three items (DataFrame, function, expected result). For each
        input pair, the entry using the method name comes first, followed by
        one entry per ``cython_table`` function mapped to that name.
    """
    params = []
    for func_name, expected in func_names_and_expected:
        params.append((ndframe, func_name, expected))
        params.extend(
            (ndframe, func, expected)
            for func, name in cython_table
            if name == func_name
        )
    return params
def get_op_from_name(op_name: str) -> Callable:
    """
    The operator function for a given op name.

    Parameters
    ----------
    op_name : str
        The op name, in form of "add" or "__add__". Reversed names such as
        "radd" / "__radd__" are supported, as are names that the ``operator``
        module spells with a trailing underscore ("and" -> ``and_``,
        "or" -> ``or_``), including their reversed forms.

    Returns
    -------
    function
        A function performing the operation.

    Raises
    ------
    AttributeError
        If neither the name nor its reversed form exists in ``operator``.
    """
    short_opname = op_name.strip("_")

    def _lookup(name: str):
        # ``operator`` appends an underscore to names that collide with
        # Python keywords (``and_``, ``or_``); the ``strip`` above removes
        # it, so try both spellings.
        for candidate in (name, f"{name}_"):
            func = getattr(operator, candidate, None)
            if func is not None:
                return func
        return None

    op = _lookup(short_opname)
    if op is not None:
        return op

    # Assume it is the reverse operator
    rop = _lookup(short_opname[1:])
    if rop is None:
        raise AttributeError(
            f"module 'operator' has no function matching {op_name!r}"
        )

    def reversed_op(x, y):
        # Swap the operands so that ``reversed_op(x, y)`` computes ``y <op> x``.
        return rop(y, x)

    return reversed_op
# -----------------------------------------------------------------------------
# Indexing test helpers
def getitem(x):
    """Return ``x`` itself, unchanged."""
    return x
def setitem(x):
    """Return ``x`` itself, unchanged."""
    return x
def loc(x):
    """Return the ``.loc`` attribute of ``x``."""
    return x.loc
def iloc(x):
    """Return the ``.iloc`` attribute of ``x``."""
    return x.iloc
def at(x):
    """Return the ``.at`` attribute of ``x``."""
    return x.at
def iat(x):
    """Return the ``.iat`` attribute of ``x``."""
    return x.iat
# -----------------------------------------------------------------------------
# Supported units, ordered from coarsest to finest.
_UNITS = ["s", "ms", "us", "ns"]


def get_finest_unit(left: str, right: str):
    """
    Return whichever of two datetime64 units appears later in ``_UNITS``.

    ``left`` is returned when both units are the same. A unit that is not
    listed in ``_UNITS`` raises ValueError.
    """
    # ``max`` keeps its first argument on ties, so equal units yield ``left``.
    return max(left, right, key=_UNITS.index)
def shares_memory(left, right) -> bool:
    """
    Pandas-compat for np.shares_memory.

    ``left`` is unwrapped one layer at a time (Index/Series -> values,
    array wrapper -> backing ndarray, ...) by recursive calls until both
    sides are ndarrays, at which point ``np.shares_memory`` decides.

    Parameters
    ----------
    left, right : object
        ndarrays or the pandas containers / arrays handled below.

    Returns
    -------
    bool

    Raises
    ------
    NotImplementedError
        If ``left`` matches none of the handled cases.
    """
    if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
        return np.shares_memory(left, right)
    elif isinstance(left, np.ndarray):
        # Call with reversed args to get to unpacking logic below.
        return shares_memory(right, left)

    # A RangeIndex is always reported as not sharing memory.
    if isinstance(left, RangeIndex):
        return False
    # A MultiIndex is compared through its ``_codes``.
    if isinstance(left, MultiIndex):
        return shares_memory(left._codes, right)
    if isinstance(left, (Index, Series)):
        return shares_memory(left._values, right)

    if isinstance(left, NDArrayBackedExtensionArray):
        return shares_memory(left._ndarray, right)
    if isinstance(left, pd.core.arrays.SparseArray):
        return shares_memory(left.sp_values, right)
    if isinstance(left, pd.core.arrays.IntervalArray):
        # Sharing either the left or the right bounds counts.
        return shares_memory(left._left, right) or shares_memory(left._right, right)

    if (
        isinstance(left, ExtensionArray)
        and is_string_dtype(left.dtype)
        and left.dtype.storage in ("pyarrow", "pyarrow_numpy")  # type: ignore[attr-defined]
    ):
        # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
        left = cast("ArrowExtensionArray", left)
        if (
            isinstance(right, ExtensionArray)
            and is_string_dtype(right.dtype)
            and right.dtype.storage in ("pyarrow", "pyarrow_numpy")  # type: ignore[attr-defined]
        ):
            right = cast("ArrowExtensionArray", right)
            left_pa_data = left._pa_array
            right_pa_data = right._pa_array
            # Only buffer index 1 of the first chunk on each side is compared.
            left_buf1 = left_pa_data.chunk(0).buffers()[1]
            right_buf1 = right_pa_data.chunk(0).buffers()[1]
            return left_buf1 == right_buf1
        # If ``right`` is not also a pyarrow-backed string array, execution
        # falls through to the remaining checks below.

    if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray):
        # By convention, we'll say these share memory if they share *either*
        #  the _data or the _mask
        return np.shares_memory(left._data, right._data) or np.shares_memory(
            left._mask, right._mask
        )

    # Only DataFrames whose manager holds exactly one array are handled.
    if isinstance(left, DataFrame) and len(left._mgr.arrays) == 1:
        arr = left._mgr.arrays[0]
        return shares_memory(arr, right)

    raise NotImplementedError(type(left), type(right))
# Names exported by ``from <this module> import *``: a selection of the
# constants, assertion helpers, context managers and utilities defined or
# imported above.
__all__ = [
    "ALL_INT_EA_DTYPES",
    "ALL_INT_NUMPY_DTYPES",
    "ALL_NUMPY_DTYPES",
    "ALL_REAL_NUMPY_DTYPES",
    "assert_almost_equal",
    "assert_attr_equal",
    "assert_categorical_equal",
    "assert_class_equal",
    "assert_contains_all",
    "assert_copy",
    "assert_datetime_array_equal",
    "assert_dict_equal",
    "assert_equal",
    "assert_extension_array_equal",
    "assert_frame_equal",
    "assert_index_equal",
    "assert_indexing_slices_equivalent",
    "assert_interval_array_equal",
    "assert_is_sorted",
    "assert_is_valid_plot_return_object",
    "assert_metadata_equivalent",
    "assert_numpy_array_equal",
    "assert_period_array_equal",
    "assert_produces_warning",
    "assert_series_equal",
    "assert_sp_array_equal",
    "assert_timedelta_array_equal",
    "assert_cow_warning",
    "at",
    "BOOL_DTYPES",
    "box_expected",
    "BYTES_DTYPES",
    "can_set_locale",
    "COMPLEX_DTYPES",
    "convert_rows_list_to_csv_str",
    "DATETIME64_DTYPES",
    "decompress_file",
    "ENDIAN",
    "ensure_clean",
    "external_error_raised",
    "FLOAT_EA_DTYPES",
    "FLOAT_NUMPY_DTYPES",
    "get_cython_table_params",
    "get_dtype",
    "getitem",
    "get_locales",
    "get_finest_unit",
    "get_obj",
    "get_op_from_name",
    "iat",
    "iloc",
    "loc",
    "maybe_produces_warning",
    "NARROW_NP_DTYPES",
    "NP_NAT_OBJECTS",
    "NULL_OBJECTS",
    "OBJECT_DTYPES",
    "raise_assert_detail",
    "raises_chained_assignment_error",
    "round_trip_localpath",
    "round_trip_pathlib",
    "round_trip_pickle",
    "setitem",
    "set_locale",
    "set_timezone",
    "shares_memory",
    "SIGNED_INT_EA_DTYPES",
    "SIGNED_INT_NUMPY_DTYPES",
    "STRING_DTYPES",
    "SubclassedDataFrame",
    "SubclassedSeries",
    "TIMEDELTA64_DTYPES",
    "to_array",
    "UNSIGNED_INT_EA_DTYPES",
    "UNSIGNED_INT_NUMPY_DTYPES",
    "use_numexpr",
    "with_csv_dialect",
    "write_to_compressed",
]