You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

28 lines
731 B

# -*- coding: utf-8 -*-
"""
Get the interface module of the loss function.
Created on 2024/9/23 17:00
@author: Whenxuan Wang
@email: wwhenxuan@gmail.com
@url: https://github.com/wwhenxuan/SymTime
"""
from torch import nn
from typing import Callable
# Registry mapping a short loss-function name to its PyTorch loss *class*
# (not an instance); callers instantiate the returned class themselves.
_CRITERIA = {
    "MSE": nn.MSELoss,
    "MAE": nn.L1Loss,
    "CEL": nn.CrossEntropyLoss,
    # SmoothL1Loss with its default beta=1.0 coincides with Huber loss
    # with delta=1.0, which is why it is registered under "Huber".
    "Huber": nn.SmoothL1Loss,
    "Cos": nn.CosineEmbeddingLoss,
}


def get_criterion(name: str = "MSE") -> Callable:
    """Return the PyTorch loss class registered under ``name``.

    Args:
        name: Short identifier of the loss function. One of ``"MSE"``,
            ``"MAE"``, ``"CEL"``, ``"Huber"`` or ``"Cos"`` (case-sensitive).
            Defaults to ``"MSE"``.

    Returns:
        The loss class itself (e.g. ``nn.MSELoss``), not an instance; call
        it to construct the criterion, e.g. ``get_criterion("MAE")()``.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    try:
        return _CRITERIA[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the original equality
        # chain also rejected with ValueError.
        supported = ", ".join(_CRITERIA)
        raise ValueError(
            f"The loss function name is incorrect: {name!r}. "
            f"Supported names: {supported}."
        ) from None