You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
28 lines
731 B
28 lines
731 B
# -*- coding: utf-8 -*-
|
|
"""
|
|
Provide a name-based lookup for the PyTorch loss functions used in training.
|
|
|
|
Created on 2024/9/23 17:00
|
|
@author: Whenxuan Wang
|
|
@email: wwhenxuan@gmail.com
|
|
@url: https://github.com/wwhenxuan/SymTime
|
|
"""
|
|
from torch import nn
|
|
from typing import Callable
|
|
|
|
|
|
# Mapping from the short names accepted by ``get_criterion`` to the
# corresponding ``torch.nn`` loss classes (not instances). Lookup is
# case-sensitive.
_LOSS_FUNCTIONS = {
    "MSE": nn.MSELoss,
    "MAE": nn.L1Loss,
    "CEL": nn.CrossEntropyLoss,
    # SmoothL1Loss serves the "Huber" name; per the PyTorch docs it
    # coincides with HuberLoss at the default beta/delta of 1.0.
    "Huber": nn.SmoothL1Loss,
    "Cos": nn.CosineEmbeddingLoss,
}


def get_criterion(name: str = "MSE") -> Callable[..., nn.Module]:
    """Return the neural-network loss class registered under ``name``.

    The returned object is the loss *class* itself (e.g. ``nn.MSELoss``);
    the caller is responsible for instantiating it, e.g.
    ``get_criterion("MAE")()``.

    :param name: Case-sensitive short name of the loss. One of
        ``"MSE"``, ``"MAE"``, ``"CEL"``, ``"Huber"`` or ``"Cos"``.
        Defaults to ``"MSE"``.
    :return: The matching ``torch.nn`` loss class.
    :raises ValueError: If ``name`` is not one of the supported names.
    """
    try:
        return _LOSS_FUNCTIONS[name]
    except KeyError:
        # Report the rejected name and the valid choices so that a typo in a
        # config file is easy to diagnose; `from None` hides the KeyError.
        valid = ", ".join(_LOSS_FUNCTIONS)
        raise ValueError(
            f"The loss function name is incorrect: {name!r}. "
            f"Expected one of: {valid}."
        ) from None
|