Fix errors
This commit is contained in:
parent
7863611f86
commit
35eb747abf
3 changed files with 26 additions and 73 deletions
|
|
@ -47,7 +47,7 @@ def load_model(path: Path, device: torch.device) -> Tuple[nn.Module, Dict[str, A
|
|||
model_path = path.with_suffix(".pth")
|
||||
model = create_model(
|
||||
type=hyperparameters["model_type"],
|
||||
**hyperparameters,
|
||||
hyperparameters=hyperparameters,
|
||||
device=device,
|
||||
)
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
|
|
|
|||
|
|
@ -33,5 +33,10 @@ class SimpleCNN(nn.Module):
|
|||
x = F.relu(self.conv4(x))
|
||||
x = F.relu(self.conv5(x))
|
||||
x = self.conv6(x)
|
||||
sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / sum
|
||||
return self._normalize(x)
|
||||
|
||||
@staticmethod
|
||||
def _normalize(x):
|
||||
x = torch.clamp(x, min=0)
|
||||
x_sum = torch.sum(x, dim=(2, 3, 4), keepdim=True)
|
||||
return x / (x_sum + 1e-5)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue