You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

189 lines
5.8 KiB

7 months ago
from __future__ import annotations
import datetime as dt
from typing import (
TYPE_CHECKING,
Any,
cast,
)
import numpy as np
from pandas.core.dtypes.dtypes import register_extension_dtype
from pandas.api.extensions import (
ExtensionArray,
ExtensionDtype,
)
from pandas.api.types import pandas_dtype
if TYPE_CHECKING:
from collections.abc import Sequence
from pandas._typing import (
Dtype,
PositionalIndexer,
)
@register_extension_dtype
class DateDtype(ExtensionDtype):
    """Extension dtype whose scalar type is ``datetime.date``.

    The dtype is identified by the string ``"DateDtype"`` and is backed by
    ``DateArray`` (see ``construct_array_type``).
    """

    def __repr__(self) -> str:
        # The textual representation is simply the dtype's name.
        return self.name

    @property
    def name(self):
        """String identifier of this dtype."""
        return "DateDtype"

    @property
    def type(self):
        """Scalar type of the elements held by arrays of this dtype."""
        return dt.date

    @property
    def na_value(self):
        """Value used to represent a missing entry: ``datetime.date.min``."""
        return dt.date.min

    @classmethod
    def construct_array_type(cls):
        """Return the array class associated with this dtype."""
        return DateArray

    @classmethod
    def construct_from_string(cls, string: str):
        """Build the dtype from its string name.

        Parameters
        ----------
        string : str
            Must be exactly the class name.

        Raises
        ------
        TypeError
            If ``string`` is not a ``str`` or does not equal the class name.
        """
        if not isinstance(string, str):
            received = type(string)
            raise TypeError(
                f"'construct_from_string' expects a string, got {received}"
            )

        expected = cls.__name__
        if string != expected:
            raise TypeError(f"Cannot construct a '{expected}' from '{string}'")
        return cls()
class DateArray(ExtensionArray):
    """Extension array of calendar dates stored as three parallel arrays.

    The year, month and day of every element live in the separate NumPy
    arrays ``_year``, ``_month`` and ``_day``, which always have the same
    length.  ``datetime.date.min`` is treated as the missing-value marker
    (see ``isna`` and ``astype``).
    """

    def __init__(
        self,
        dates: (
            dt.date
            | Sequence[dt.date]
            | tuple[np.ndarray, np.ndarray, np.ndarray]
            | np.ndarray
        ),
    ) -> None:
        """Build the array from one of the supported inputs.

        Parameters
        ----------
        dates
            * a single ``datetime.date`` (yields a length-1 array),
            * a ``list`` of ``datetime.date`` objects,
            * a ``(years, months, days)`` triple of equally long ndarrays,
            * an ndarray of dtype ``"U10"`` holding ``yyyy-mm-dd`` strings.

        Raises
        ------
        ValueError
            If a tuple is not a triple or its members differ in length.
        TypeError
            If a tuple member is not an ndarray, or the input kind is not
            one of the supported ones.
        """
        if isinstance(dates, dt.date):
            self._year = np.array([dates.year])
            self._month = np.array([dates.month])
            # The day component must come from ``.day``; it was previously
            # filled from ``.year``, which produced invalid dates.
            self._day = np.array([dates.day])
            return

        ldates = len(dates)
        if isinstance(dates, list):
            # pre-allocate the arrays since we know the size beforehand
            self._year = np.zeros(ldates, dtype=np.uint16)  # 65535 (0, 9999)
            self._month = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 12)
            self._day = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 31)
            # populate them
            for i, date in enumerate(dates):
                self._year[i] = date.year
                self._month[i] = date.month
                self._day[i] = date.day

        elif isinstance(dates, tuple):
            # only support triples
            if ldates != 3:
                raise ValueError("only triples are valid")
            # every member of the triple has to be an ndarray
            if any(not isinstance(x, np.ndarray) for x in dates):
                raise TypeError("invalid type")
            ly, lm, ld = (len(cast(np.ndarray, d)) for d in dates)
            if not ly == lm == ld:
                raise ValueError(
                    f"tuple members must have the same length: {(ly, lm, ld)}"
                )
            # ``astype`` copies, so the new array owns its data.
            self._year = dates[0].astype(np.uint16)
            self._month = dates[1].astype(np.uint8)
            self._day = dates[2].astype(np.uint8)

        elif isinstance(dates, np.ndarray) and dates.dtype == "U10":
            self._year = np.zeros(ldates, dtype=np.uint16)  # 65535 (0, 9999)
            self._month = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 12)
            self._day = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 31)

            # Each element becomes a ``[yyyy, mm, dd]`` list of strings.
            # error: "object_" object is not iterable
            obj = np.char.split(dates, sep="-")
            for (i,), (y, m, d) in np.ndenumerate(obj):  # type: ignore[misc]
                self._year[i] = int(y)
                self._month[i] = int(m)
                self._day[i] = int(d)

        else:
            raise TypeError(f"{type(dates)} is not supported")

    @property
    def dtype(self) -> ExtensionDtype:
        """The dtype of this array, always a ``DateDtype`` instance."""
        return DateDtype()

    def astype(self, dtype, copy=True):
        """Cast to ``dtype``.

        For ``DateDtype`` this returns a copy of the array (or the array
        itself when ``copy`` is false).  Any other dtype is delegated to
        ``to_numpy`` with ``datetime.date.min`` as the NA value.
        """
        dtype = pandas_dtype(dtype)

        if isinstance(dtype, DateDtype):
            data = self.copy() if copy else self
        else:
            data = self.to_numpy(dtype=dtype, copy=copy, na_value=dt.date.min)

        return data

    @property
    def nbytes(self) -> int:
        """Total number of bytes used by the three component arrays."""
        return self._year.nbytes + self._month.nbytes + self._day.nbytes

    def __len__(self) -> int:
        return len(self._year)  # all 3 arrays are enforced to have the same length

    def __getitem__(self, item: PositionalIndexer):
        """Return one element or a sliced sub-array.

        An integer (Python or NumPy) returns a ``datetime.date``.  A slice
        returns a new ``DateArray`` holding a copy of the selected elements.

        Raises
        ------
        NotImplementedError
            For any other kind of indexer.
        """
        if isinstance(item, (int, np.integer)):
            return dt.date(
                int(self._year[item]), int(self._month[item]), int(self._day[item])
            )
        if isinstance(item, slice):
            # Needed by ``_from_sequence``, which returns ``scalars[:]``.
            return DateArray((self._year[item], self._month[item], self._day[item]))
        raise NotImplementedError("only ints and slices are supported as indexes")

    def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
        """Assign a ``datetime.date`` at an integer position.

        Raises
        ------
        NotImplementedError
            If ``key`` is not an integer.
        TypeError
            If ``value`` is not a ``datetime.date``.
        """
        if not isinstance(key, (int, np.integer)):
            raise NotImplementedError("only ints are supported as indexes")

        if not isinstance(value, dt.date):
            raise TypeError("you can only set datetime.date types")

        self._year[key] = value.year
        self._month[key] = value.month
        self._day[key] = value.day

    def __repr__(self) -> str:
        return f"DateArray{list(zip(self._year, self._month, self._day))}"

    def copy(self) -> DateArray:
        """Return a new ``DateArray`` with copied component arrays."""
        return DateArray((self._year.copy(), self._month.copy(), self._day.copy()))

    def isna(self) -> np.ndarray:
        """Boolean mask that is true where the element equals ``date.min``."""
        na = dt.date.min
        return np.logical_and(
            np.logical_and(self._year == na.year, self._month == na.month),
            self._day == na.day,
        )

    @classmethod
    def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
        """Construct a ``DateArray`` from ``scalars``.

        Parameters
        ----------
        scalars
            A ``DateArray``, an ndarray (converted to ``yyyy-mm-dd``
            strings via ``astype("U10")``) or a ``list`` of
            ``datetime.date`` objects.
        dtype
            Only consulted when ``scalars`` is a ``DateArray``; the array is
            then cast with ``astype``.
        copy
            Only consulted when ``scalars`` is a ``DateArray``.

        Raises
        ------
        TypeError
            If ``scalars`` is a single date or an unsupported container.
        """
        if isinstance(scalars, dt.date):
            raise TypeError("a single date is not a sequence")
        if isinstance(scalars, DateArray):
            if dtype is not None:
                return scalars.astype(dtype, copy=copy)
            if copy:
                return scalars.copy()
            return scalars[:]
        if isinstance(scalars, np.ndarray):
            scalars = scalars.astype("U10")  # 10 chars for yyyy-mm-dd
            return DateArray(scalars)
        if isinstance(scalars, list):
            # The constructor already understands lists of dates.
            return DateArray(scalars)
        # Fail loudly instead of implicitly returning ``None``.
        raise TypeError(f"{type(scalars)} is not supported")