Fix bugs
This commit is contained in:
parent
ae2995d0e9
commit
137ba1c475
2 changed files with 12 additions and 18 deletions
|
|
@ -28,7 +28,6 @@ class HistogramNet(nn.Module):
|
|||
self._use_elu = use_elu
|
||||
self._leaky_relu_alpha = leaky_relu_alpha
|
||||
self._use_residual = use_residual
|
||||
self.print_og_result = False
|
||||
|
||||
self._convolutions = nn.ModuleList(
|
||||
self._make_conv_layer(in_channels=in_channels, out_channels=out_channels)
|
||||
|
|
@ -55,7 +54,7 @@ class HistogramNet(nn.Module):
|
|||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=self._kernel_size,
|
||||
padding=1,
|
||||
padding=self._kernel_size // 2,
|
||||
bias=False,
|
||||
),
|
||||
(
|
||||
|
|
@ -75,21 +74,22 @@ class HistogramNet(nn.Module):
|
|||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
kernel_size=self._kernel_size,
|
||||
padding=1,
|
||||
padding=self._kernel_size // 2,
|
||||
bias=False,
|
||||
),
|
||||
(
|
||||
nn.ELU(self._elu_alpha)
|
||||
if self._use_elu
|
||||
else nn.LeakyReLU(self._leaky_relu_alpha)(
|
||||
nn.InstanceNorm3d if self._use_instance_norm else nn.BatchNorm3d
|
||||
)(channels)
|
||||
else nn.LeakyReLU(self._leaky_relu_alpha)
|
||||
),
|
||||
(nn.InstanceNorm3d if self._use_instance_norm else nn.BatchNorm3d)(
|
||||
channels
|
||||
),
|
||||
nn.Conv3d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
kernel_size=self._kernel_size,
|
||||
padding=1,
|
||||
padding=self._kernel_size // 2,
|
||||
bias=False,
|
||||
),
|
||||
(
|
||||
|
|
@ -108,7 +108,7 @@ class HistogramNet(nn.Module):
|
|||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=self._kernel_size,
|
||||
padding=1,
|
||||
padding=self._kernel_size // 2,
|
||||
),
|
||||
(
|
||||
nn.ELU(self._elu_alpha)
|
||||
|
|
@ -129,10 +129,6 @@ class HistogramNet(nn.Module):
|
|||
for deconv in self._deconvolutions:
|
||||
x = deconv(x)
|
||||
|
||||
if self.print_og_result:
|
||||
logging.info(f"Original result {torch.sum(x)}")
|
||||
self.print_og_result = False
|
||||
|
||||
return self._normalize(x)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -144,7 +140,6 @@ class HistogramNet(nn.Module):
|
|||
def _initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
|
||||
# Applying He normal initialization
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ def random_hparam_search(
|
|||
device: torch.device,
|
||||
) -> None:
|
||||
for _ in count():
|
||||
run_id = get_next_run_name(tensorboard_path)
|
||||
current_hyperparameters = {
|
||||
k: v.rvs() if hasattr(v, "rvs") else choice(v)
|
||||
for k, v in choice(hyperparameters).items()
|
||||
|
|
@ -29,11 +30,9 @@ def random_hparam_search(
|
|||
serialized_hparams = json.dumps(
|
||||
current_hyperparameters, indent=2, sort_keys=True
|
||||
)
|
||||
logging.info(
|
||||
f"Starting {get_next_run_name(tensorboard_path)} with hparams {serialized_hparams}"
|
||||
)
|
||||
logging.info(f"Starting {run_id} with hparams {serialized_hparams}")
|
||||
|
||||
log_dir = tensorboard_path / get_next_run_name(tensorboard_path)
|
||||
log_dir = tensorboard_path / run_id
|
||||
|
||||
try:
|
||||
model = train(
|
||||
|
|
@ -46,7 +45,7 @@ def random_hparam_search(
|
|||
device=device,
|
||||
**current_hyperparameters,
|
||||
)
|
||||
model_path = models_path / get_next_run_name(models_path)
|
||||
model_path = models_path / run_id
|
||||
save_model(model, current_hyperparameters, model_path)
|
||||
del model
|
||||
except KeyboardInterrupt as e:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue