From 856dc83c774e9ce3b4d012638be5ca8192441519 Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Tue, 25 Jun 2024 09:00:20 +0100 Subject: [PATCH] Fix loss explosion? --- src/models/residual3.py | 11 +++++++---- src/training/train.py | 17 ++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/models/residual3.py b/src/models/residual3.py index a492796..349c52b 100644 --- a/src/models/residual3.py +++ b/src/models/residual3.py @@ -2,6 +2,9 @@ import torch import torch.nn as nn +EPSILON = 1e-5 + + class Residual3(nn.Module): def __init__( self, @@ -124,11 +127,11 @@ class Residual3(nn.Module): out = self.deconv2(out) out = self.deconv3(out) - return self._normalize(out) + return out - def _normalize(self, x): - x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True) - return x / (x_sum + 1e-6) + # def _normalize(self, x): + # x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True) + # return x / (x_sum + EPSILON) def _initialize_weights(self): for m in self.modules(): diff --git a/src/training/train.py b/src/training/train.py index 4d4f49b..e11ad14 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -4,19 +4,21 @@ from torch.utils.tensorboard import SummaryWriter from pathlib import Path from torch.optim import Adam from tqdm.notebook import tqdm -from .get_next_run_name import get_next_run_name from visualisation import plot_histograms_in_2d -from models import create_model, save_model +from models import create_model from datetime import timedelta, datetime from torch.utils.data import DataLoader import torch from utils import serialise_hparams +EPSILON = 1e-5 + + def train( hyperparameters: Dict[str, Any], - training_data_path: DataLoader, - test_data_path: DataLoader, + train_data_loader: DataLoader, + test_data_loader: DataLoader, log_dir: Path, max_duration: Optional[timedelta], use_tqdm: bool, @@ -31,9 +33,6 @@ def train( start_time = datetime.now() with SummaryWriter(log_dir) as writer: - train_data_loader = training_data_path - test_data_loader = test_data_path - model = create_model( type=model_type, bin_count=bin_count, @@ -65,7 +64,7 @@ def train( optimizer.zero_grad() predicted_original = model(edited_histogram.to(device)) loss = loss_function( - torch.log(torch.clamp(predicted_original, 1e-5, 1)), + torch.log(torch.clamp(predicted_original, EPSILON, 1)), original_histogram.to(device), ) @@ -102,7 +101,7 @@ def train( ): predicted_original = model(edited_histogram.to(device)) epoch_test_loss += loss_function( - torch.log(torch.clamp(predicted_original, 1e-10, 1)), + torch.log(torch.clamp(predicted_original, EPSILON, 1)), original_histogram.to(device), ).item() writer.add_hparams(