Fix loss explosion?

This commit is contained in:
Andras Schmelczer 2024-06-25 09:00:20 +01:00
parent d1b6e52f31
commit 856dc83c77
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 15 additions and 13 deletions

View file

@ -2,6 +2,9 @@ import torch
import torch.nn as nn
EPSILON = 1e-5
class Residual3(nn.Module):
def __init__(
self,
@ -124,11 +127,11 @@ class Residual3(nn.Module):
out = self.deconv2(out)
out = self.deconv3(out)
return self._normalize(out)
return out
def _normalize(self, x):
x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
return x / (x_sum + 1e-6)
# def _normalize(self, x):
# x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
# return x / (x_sum + EPSILON)
def _initialize_weights(self):
for m in self.modules():

View file

@ -4,19 +4,21 @@ from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
from tqdm.notebook import tqdm
from .get_next_run_name import get_next_run_name
from visualisation import plot_histograms_in_2d
from models import create_model, save_model
from models import create_model
from datetime import timedelta, datetime
from torch.utils.data import DataLoader
import torch
from utils import serialise_hparams
EPSILON = 1e-5
def train(
hyperparameters: Dict[str, Any],
training_data_path: DataLoader,
test_data_path: DataLoader,
train_data_loader: DataLoader,
test_data_loader: DataLoader,
log_dir: Path,
max_duration: Optional[timedelta],
use_tqdm: bool,
@ -31,9 +33,6 @@ def train(
start_time = datetime.now()
with SummaryWriter(log_dir) as writer:
train_data_loader = training_data_path
test_data_loader = test_data_path
model = create_model(
type=model_type,
bin_count=bin_count,
@ -65,7 +64,7 @@ def train(
optimizer.zero_grad()
predicted_original = model(edited_histogram.to(device))
loss = loss_function(
torch.log(torch.clamp(predicted_original, 1e-5, 1)),
torch.log(torch.clamp(predicted_original, EPSILON, 1)),
original_histogram.to(device),
)
@ -102,7 +101,7 @@ def train(
):
predicted_original = model(edited_histogram.to(device))
epoch_test_loss += loss_function(
torch.log(torch.clamp(predicted_original, 1e-10, 1)),
torch.log(torch.clamp(predicted_original, EPSILON, 1)),
original_histogram.to(device),
).item()
writer.add_hparams(