Remove clutter

This commit is contained in:
Andras Schmelczer 2024-09-01 22:10:56 +01:00
parent 68397b565a
commit 49d9ece2ec
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
4 changed files with 6 additions and 68 deletions

File diff suppressed because one or more lines are too long

View file

@@ -1,4 +1,4 @@
from .histogram_dataset import HistogramDataset
from .get_next_run_name import get_next_run_name
from .random_hparam_search import random_hparam_search
from .train import train
from .train_with_ray import train_with_ray_factory

View file

@@ -56,10 +56,7 @@ class HistogramDataset(Dataset):
original_idx = idx // self._edit_count
edit_idx = idx % self._edit_count
edited_histogram = None
original_histogram = None
cached_data_path = None
if self._cache_path is not None:
cached_data_path = self._cache_path / str(original_idx) / f"{edit_idx}.bin"
cached_data_path.parent.mkdir(parents=True, exist_ok=True)
@@ -89,13 +86,11 @@ class HistogramDataset(Dataset):
cached_data_path,
)
result = (
return (
torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0),
)
return result
@staticmethod
def save_2_histograms(tensor1: np.ndarray, tensor2: np.ndarray, path: Path):
flat_array1 = tensor1.flatten().astype(np.float32)

View file

@@ -1,57 +0,0 @@
from datetime import timedelta
import logging
from pathlib import Path
from random import choice
from itertools import count
import json
from typing import Any, Dict, List
from .train import train
from .get_next_run_name import get_next_run_name
from models import save_model
from .get_data_loader import get_data_loader
import torch
def random_hparam_search(
    hyperparameters: List[Dict[str, Any]],
    train_data_paths: List[Path],
    test_data_paths: List[Path],
    models_path: Path,
    tensorboard_path: Path,
    timeout_hours: int,
    device: torch.device,
) -> None:
    """Run random hyperparameter search until interrupted from the keyboard.

    Each iteration picks one search space at random from ``hyperparameters``,
    samples a concrete value for every key in it, trains a model with those
    values and saves the result under ``models_path / <run name>``.

    Args:
        hyperparameters: Candidate search spaces. Every value in a space is
            either an object exposing ``rvs()`` (sampled by calling it) or a
            sequence (sampled with ``random.choice``).
        train_data_paths: Forwarded unchanged to ``train``.
        test_data_paths: Forwarded unchanged to ``train``.
        models_path: Directory under which each run's model is saved.
        tensorboard_path: Directory used to derive the next run name and the
            per-run log directory.
        timeout_hours: Passed to ``train`` as ``max_duration`` (in hours).
        device: Forwarded unchanged to ``train``.

    A failing run is logged together with its traceback and the search moves
    on to the next sample; only ``KeyboardInterrupt`` ends the loop.
    """
    while True:
        run_id = get_next_run_name(tensorboard_path)

        # Pick a search space first, then draw one value per hyperparameter.
        search_space = choice(hyperparameters)
        current_hyperparameters = {
            k: v.rvs() if hasattr(v, "rvs") else choice(v)
            for k, v in search_space.items()
        }

        serialized_hparams = json.dumps(
            current_hyperparameters, indent=2, sort_keys=True
        )
        logging.info(f"Starting {run_id} with hparams {serialized_hparams}")

        log_dir = tensorboard_path / run_id
        try:
            # NOTE(review): the sampled values are passed both as a single
            # dict and expanded as keyword arguments; this presumably relies
            # on no sampled key colliding with train()'s own parameter
            # names -- confirm against train().
            model = train(
                hyperparameters=current_hyperparameters,
                train_data_paths=train_data_paths,
                test_data_paths=test_data_paths,
                max_duration=timedelta(hours=timeout_hours),
                log_dir=log_dir,
                use_tqdm=False,
                device=device,
                **current_hyperparameters,
            )
            model_path = models_path / run_id
            save_model(model, current_hyperparameters, model_path)
            # Drop the reference before the next run allocates a new model.
            del model
        except KeyboardInterrupt:
            logging.info("Interrupted, stopping")
            break
        except Exception as e:
            # exc_info attaches the traceback of the caught exception;
            # stack_info would only record where this logging call was made.
            logging.error(
                f"Error with hparams {current_hyperparameters}:\n\t{e}",
                exc_info=True,
            )