You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
33 lines
934 B
33 lines
934 B
5 months ago
|
import enum
|
||
|
from typing import Sequence, Type, TypeVar
|
||
|
|
||
|
# Type variable used to make `from_str` return the concrete enum subclass it is called on.
T = TypeVar("T", bound=enum.Enum)


class StrEnumMeta(enum.EnumMeta):
    """Enum metaclass adding lookup of members by their name string.

    ``enum.auto`` is exposed as the attribute ``auto``, so enum classes built
    with this metaclass can reach it as ``<EnumClass>.auto`` (class attribute
    lookup falls back to the metaclass).
    """

    auto = enum.auto

    def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
        """Return the member of this enum class whose name is ``member``.

        The lookup is ``self[member]``, i.e. by member *name*, not by value.

        Raises:
            ValueError: if the enum has no member with that name. The original
                ``KeyError`` is suppressed (``from None``) so the traceback only
                shows the ``ValueError``.
        """
        try:
            found = self[member]
        except KeyError:
            # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
            # soon as it is migrated.
            message = f"Unknown value '{member}' for {self.__name__}."
            raise ValueError(message) from None
        return found
|
||
|
|
||
|
|
||
|
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
    """Enum base class whose subclasses get ``from_str`` name lookup from ``StrEnumMeta``."""
|
||
|
|
||
|
|
||
|
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
    """Render ``seq`` as a comma-separated list of single-quoted items.

    Args:
        seq: Items to render. An empty sequence yields ``""``.
        separate_last: Text inserted verbatim before the final item, e.g.
            ``"or "``. It is not padded, so include any trailing space yourself.

    Returns:
        For example ``"'a', 'b', 'c'"``, ``"'a', 'b', or 'c'"`` (with
        ``separate_last="or "``), or ``"'a' or 'b'"``. With exactly two items and
        a non-empty ``separate_last``, the comma before the last item is omitted.
    """
    if not seq:
        return ""

    final = f"'{seq[-1]}'"
    if len(seq) == 1:
        return final

    # Leading items are converted with str() explicitly; the last item goes
    # through f-string formatting, matching the historical output exactly.
    leading = ", ".join(f"'{item!s}'" for item in seq[:-1])

    if separate_last and len(seq) == 2:
        comma = ""
    else:
        comma = ","
    return f"{leading}{comma} {separate_last}{final}"
|