"""Common ``typing`` names plus a ``Tensor`` annotation alias.

At runtime ``Tensor`` is a plain ``TypeVar`` placeholder, so importing this
module does not import torch. Static type checkers (which evaluate the
``TYPE_CHECKING`` branch) see the real ``torch.Tensor`` class instead.
"""
# The typing names are kept so they stay importable from this module.
from typing import (  # noqa: F401
    TYPE_CHECKING,
    Any,
    Callable,
    List,
    Tuple,
    TypeVar,
    Union,
)

if TYPE_CHECKING:
    # torch.Tensor is the tensor class; torch.tensor (lowercase) is a factory
    # function and is not usable as an annotation.
    from torch import Tensor
else:
    # A TypeVar's name string should match the variable it is bound to.
    Tensor = TypeVar('Tensor')