You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
6.3 KiB
186 lines
6.3 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
Created on 2024/9/28 10:10
|
|
@author: Whenxuan Wang
|
|
@email: wwhenxuan@gmail.com
|
|
@url: https://github.com/wwhenxuan/SymTime
|
|
"""
|
|
import os
|
|
from os import path
|
|
|
|
import torch
|
|
import pandas as pd
|
|
from matplotlib import pyplot as plt
|
|
|
|
from .tools import time_now
|
|
|
|
from typing import Tuple, Union, Callable, Dict, List
|
|
|
|
|
|
class Logging(object):
|
|
"""The interface for logging experimental results"""
|
|
|
|
def __init__(self, is_pretrain: bool, logging_path: str, datasets: List) -> None:
|
|
# Determine whether it is pre-training
|
|
self.is_pretrain = is_pretrain
|
|
|
|
# Datasets excluded or used
|
|
self.datasets = datasets
|
|
|
|
# The address where the recording is performed
|
|
self.logging_path = logging_path
|
|
|
|
# Get the data dictionary and specific methods of recording
|
|
self.dict, self.logging_epoch = self.init_logging()
|
|
|
|
# Create a TXT file that can be written
|
|
self.text = create_txt_file(
|
|
file_path=self.logging_path, file_name="pretrain.txt"
|
|
)
|
|
|
|
def init_logging(self) -> Tuple[Dict, Callable]:
|
|
"""Returns a dictionary and method of the corresponding form according to the training type"""
|
|
# If it is pre-training
|
|
return {
|
|
"time": [],
|
|
"epoch": [],
|
|
"loss": [],
|
|
"loss_mtm": [],
|
|
"loss_mlm": [],
|
|
"loss_t2s": [],
|
|
"loss_s2t": [],
|
|
}, self.logging_pretrain
|
|
|
|
def logging_pretrain(
|
|
self,
|
|
epoch: int,
|
|
loss: Union[float, torch.Tensor],
|
|
loss_mtm: Union[float, torch.Tensor],
|
|
loss_mlm: Union[float, torch.Tensor],
|
|
loss_t2s: Union[float, torch.Tensor],
|
|
loss_s2t: Union[float, torch.Tensor],
|
|
) -> None:
|
|
"""Logging the training process of the pre-trained model"""
|
|
self.dict["time"].append(time_now()) # Get the current time
|
|
self.dict["epoch"].append(epoch) # Add the current training Epoch
|
|
self.dict["loss"].append(loss) # Get the current unsupervised pre-training loss
|
|
self.dict["loss_mtm"].append(loss_mtm)
|
|
self.dict["loss_mlm"].append(loss_mlm)
|
|
self.dict["loss_t2s"].append(loss_t2s)
|
|
self.dict["loss_s2t"].append(loss_s2t)
|
|
self.logging_txt(epoch, loss, loss_mtm, loss_mlm, loss_t2s, loss_s2t)
|
|
|
|
def logging_txt(
|
|
self,
|
|
epoch: int,
|
|
loss: Union[float, torch.Tensor],
|
|
loss_mtm: Union[float, torch.Tensor],
|
|
loss_mlm: Union[float, torch.Tensor],
|
|
loss_t2s: Union[float, torch.Tensor],
|
|
loss_s2t: Union[float, torch.Tensor],
|
|
) -> None:
|
|
"""Write the results to txt file"""
|
|
content = f"epoch={epoch}, loss={loss}, loss_mtm={loss_mtm}, loss_mlm={loss_mlm}, loss_t2s={loss_t2s}, loss_s2t={loss_s2t}"
|
|
write_to_txt(file_path=self.text, content=content)
|
|
|
|
def dict2csv(self) -> None:
|
|
"""Write the recorded dictionary into a csv file"""
|
|
df = pd.DataFrame(self.dict)
|
|
df.to_csv(path.join(self.logging_path, "logging.csv"), index=False)
|
|
|
|
def plot_results(self) -> None:
|
|
"""Function for visualizing experimental results"""
|
|
|
|
fig, ax = plt.subplots(figsize=(10, 4))
|
|
if self.is_pretrain is True:
|
|
# ax.plot(self.dict["epoch"], self.dict["loss"], color='royalblue', label='loss')
|
|
ax.plot(
|
|
self.dict["epoch"],
|
|
self.dict["loss_mtm"],
|
|
color="tomato",
|
|
label="loss_mtm",
|
|
)
|
|
ax.plot(
|
|
self.dict["epoch"],
|
|
self.dict["loss_mlm"],
|
|
color="royalblue",
|
|
label="loss_mlm",
|
|
)
|
|
ax.plot(
|
|
self.dict["epoch"],
|
|
self.dict["loss_t2s"],
|
|
color="#6FAE45",
|
|
label="loss_t2s",
|
|
)
|
|
ax.plot(
|
|
self.dict["epoch"],
|
|
self.dict["loss_s2t"],
|
|
color="darkorange",
|
|
label="loss_s2t",
|
|
)
|
|
ax.set_xlabel("num_epoch", fontsize=16)
|
|
ax.set_ylabel("loss", fontsize=16)
|
|
ax.legend(loc="best", fontsize=15)
|
|
else:
|
|
ax_twinx = ax.twinx()
|
|
ax.set_xlabel("num_epoch", fontsize=16)
|
|
ax.set_ylabel("loss", fontsize=16)
|
|
ax_twinx.set_ylabel("metric", fontsize=16)
|
|
ax.plot(
|
|
self.dict["epoch"],
|
|
self.dict["train_loss"],
|
|
color="royalblue",
|
|
label="Train Loss",
|
|
)
|
|
ax.plot(
|
|
self.dict["epoch"],
|
|
self.dict["test_loss"],
|
|
color="tomato",
|
|
label="Test Loss",
|
|
)
|
|
ax.legend(loc="best", fontsize=15)
|
|
ax_twinx.plot(
|
|
self.dict["epoch"],
|
|
self.dict["train_metric"],
|
|
color="royalblue",
|
|
label="Train Metric",
|
|
)
|
|
ax_twinx.plot(
|
|
self.dict["epoch"],
|
|
self.dict["test_metric"],
|
|
color="tomato",
|
|
label="Test Metric",
|
|
)
|
|
fig.savefig(
|
|
path.join(self.logging_path, "plot.jpg"), bbox_inches="tight", dpi=900
|
|
)
|
|
|
|
|
|
def create_txt_file(file_path: str, file_name: str) -> str:
|
|
"""
|
|
Creates a TXT file in the specified directory.
|
|
|
|
:param file_path: The directory path where the file is to be created.
|
|
:param file_name: The name of the file to be created (including the .txt extension).
|
|
"""
|
|
if not os.path.exists(file_path):
|
|
assert OSError
|
|
# Full file path
|
|
full_file_path = os.path.join(file_path, file_name)
|
|
# Create and open a file
|
|
with open(full_file_path, "w", encoding="utf-8") as file:
|
|
pass # Create a file without writing anything
|
|
return full_file_path
|
|
|
|
|
|
def write_to_txt(file_path: str, content: str, mode: str = "a") -> None:
|
|
"""
|
|
Writes content to a TXT file based on the passed parameters.
|
|
|
|
:param file_path: The full path to the file
|
|
:param content: The content to be written
|
|
:param mode: The write mode. Defaults to 'a' (append mode). 'w' overwrites the existing content.
|
|
"""
|
|
with open(file_path, mode, encoding="utf-8") as file:
|
|
file.write(content + "\n") # Write the content and add a newline at the end
|