You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

81 lines
3.2 KiB

6 months ago
import pytest
from pandas import (
DataFrame,
Series,
option_context,
)
import pandas._testing as tm
# Module-level marker: pytest applies the ``single_cpu`` mark to every test in
# this file. NOTE(review): presumably this keeps these tests out of parallel
# (multi-worker) runs — confirm against the project's pytest marker config.
pytestmark = pytest.mark.single_cpu
# Skip the whole module at collection time when numba cannot be imported,
# since every test below exercises (or configures) the numba engine.
pytest.importorskip("numba")
# Filter warnings when parallel=True and the function can't be parallelized by
# Numba. Note that the "ignore" filter silences *every* warning raised inside
# this class, not only Numba's.
@pytest.mark.filterwarnings("ignore")
class TestEngine:
    """Compare groupby reductions run with ``engine="numba"`` against the
    default engine, and check that unsupported options raise.

    The ``sort``, ``nogil``, ``parallel``, ``nopython`` and
    ``numba_supported_reductions`` arguments are pytest fixtures supplied by a
    conftest that is not visible here. ``numba_supported_reductions`` is
    unpacked as ``(method_name, kwargs)``.
    """

    @staticmethod
    def _make_frame():
        """Return a fresh copy of the small frame shared by the frame-based tests.

        Column ``"a"`` holds the group keys (values 3 and 2, two rows each);
        ``"b"`` and ``"c"`` are the value columns.
        """
        return DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})

    @staticmethod
    def _run_both_engines(gb, func, kwargs, nogil, parallel, nopython):
        """Run reduction ``func`` on ``gb`` with the numba engine, then without it.

        Parameters
        ----------
        gb : groupby object whose method named ``func`` is called.
        func : str, name of the reduction method (e.g. from the fixture).
        kwargs : dict, extra keyword arguments forwarded to both calls.
        nogil, parallel, nopython : passed to numba via ``engine_kwargs``.

        Returns
        -------
        tuple
            ``(result, expected)`` — the numba-engine output first, then the
            default-engine output it is expected to equal.
        """
        engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
        # Numba call first, default engine second — same order as before the
        # helper was extracted.
        result = getattr(gb, func)(
            engine="numba", engine_kwargs=engine_kwargs, **kwargs
        )
        expected = getattr(gb, func)(**kwargs)
        return result, expected

    def test_cython_vs_numba_frame(
        self, sort, nogil, parallel, nopython, numba_supported_reductions
    ):
        # DataFrameGroupBy: numba and default engines must give equal frames.
        func, kwargs = numba_supported_reductions
        gb = self._make_frame().groupby("a", sort=sort)
        result, expected = self._run_both_engines(
            gb, func, kwargs, nogil, parallel, nopython
        )
        tm.assert_frame_equal(result, expected)

    def test_cython_vs_numba_getitem(
        self, sort, nogil, parallel, nopython, numba_supported_reductions
    ):
        # Column selection on a DataFrameGroupBy yields a SeriesGroupBy;
        # both engines must give equal Series.
        func, kwargs = numba_supported_reductions
        gb = self._make_frame().groupby("a", sort=sort)["c"]
        result, expected = self._run_both_engines(
            gb, func, kwargs, nogil, parallel, nopython
        )
        tm.assert_series_equal(result, expected)

    def test_cython_vs_numba_series(
        self, sort, nogil, parallel, nopython, numba_supported_reductions
    ):
        # Series grouped by its own index level (duplicate label 1 forms a
        # group of two); both engines must give equal Series.
        func, kwargs = numba_supported_reductions
        ser = Series(range(3), index=[1, 2, 1], name="foo")
        gb = ser.groupby(level=0, sort=sort)
        result, expected = self._run_both_engines(
            gb, func, kwargs, nogil, parallel, nopython
        )
        tm.assert_series_equal(result, expected)

    def test_as_index_false_unsupported(self, numba_supported_reductions):
        # The numba engine must reject as_index=False with a clear message.
        func, kwargs = numba_supported_reductions
        gb = self._make_frame().groupby("a", as_index=False)
        with pytest.raises(NotImplementedError, match="as_index=False"):
            getattr(gb, func)(engine="numba", **kwargs)

    def test_axis_1_unsupported(self, numba_supported_reductions):
        # The numba engine must reject axis=1 grouping with a clear message.
        func, kwargs = numba_supported_reductions
        gb = self._make_frame().groupby("a", axis=1)
        with pytest.raises(NotImplementedError, match="axis=1"):
            getattr(gb, func)(engine="numba", **kwargs)

    def test_no_engine_doesnt_raise(self):
        # GH55520
        gb = self._make_frame().groupby("a")
        # Make sure that functions without an engine argument don't raise when
        # the global use_numba option is set, and that they return the same
        # result as with the option at its normal setting.
        with option_context("compute.use_numba", True):
            res = gb.agg({"b": "first"})
        # Computed outside the context so it serves as the option-off baseline.
        expected = gb.agg({"b": "first"})
        tm.assert_frame_equal(res, expected)