Fix loss explosion?
This commit is contained in:
parent
d1b6e52f31
commit
856dc83c77
2 changed files with 15 additions and 13 deletions
|
|
@ -2,6 +2,9 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
EPSILON = 1e-5
|
||||
|
||||
|
||||
class Residual3(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -124,11 +127,11 @@ class Residual3(nn.Module):
|
|||
out = self.deconv2(out)
|
||||
out = self.deconv3(out)
|
||||
|
||||
return self._normalize(out)
|
||||
return out
|
||||
|
||||
def _normalize(self, x):
|
||||
x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / (x_sum + 1e-6)
|
||||
# def _normalize(self, x):
|
||||
# x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
# return x / (x_sum + EPSILON)
|
||||
|
||||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
|
|
|
|||
|
|
@ -4,19 +4,21 @@ from torch.utils.tensorboard import SummaryWriter
|
|||
from pathlib import Path
|
||||
from torch.optim import Adam
|
||||
from tqdm.notebook import tqdm
|
||||
from .get_next_run_name import get_next_run_name
|
||||
from visualisation import plot_histograms_in_2d
|
||||
from models import create_model, save_model
|
||||
from models import create_model
|
||||
from datetime import timedelta, datetime
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
from utils import serialise_hparams
|
||||
|
||||
|
||||
EPSILON = 1e-5
|
||||
|
||||
|
||||
def train(
|
||||
hyperparameters: Dict[str, Any],
|
||||
training_data_path: DataLoader,
|
||||
test_data_path: DataLoader,
|
||||
train_data_loader: DataLoader,
|
||||
test_data_loader: DataLoader,
|
||||
log_dir: Path,
|
||||
max_duration: Optional[timedelta],
|
||||
use_tqdm: bool,
|
||||
|
|
@ -31,9 +33,6 @@ def train(
|
|||
start_time = datetime.now()
|
||||
|
||||
with SummaryWriter(log_dir) as writer:
|
||||
train_data_loader = training_data_path
|
||||
test_data_loader = test_data_path
|
||||
|
||||
model = create_model(
|
||||
type=model_type,
|
||||
bin_count=bin_count,
|
||||
|
|
@ -65,7 +64,7 @@ def train(
|
|||
optimizer.zero_grad()
|
||||
predicted_original = model(edited_histogram.to(device))
|
||||
loss = loss_function(
|
||||
torch.log(torch.clamp(predicted_original, 1e-5, 1)),
|
||||
torch.log(torch.clamp(predicted_original, EPSILON, 1)),
|
||||
original_histogram.to(device),
|
||||
)
|
||||
|
||||
|
|
@ -102,7 +101,7 @@ def train(
|
|||
):
|
||||
predicted_original = model(edited_histogram.to(device))
|
||||
epoch_test_loss += loss_function(
|
||||
torch.log(torch.clamp(predicted_original, 1e-10, 1)),
|
||||
torch.log(torch.clamp(predicted_original, EPSILON, 1)),
|
||||
original_histogram.to(device),
|
||||
).item()
|
||||
writer.add_hparams(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue