diff --git a/src/config.py b/src/config.py index 779d7d8..af21908 100644 --- a/src/config.py +++ b/src/config.py @@ -4,9 +4,9 @@ from pathlib import Path # DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/downloaded-unsplash").glob("*")) DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/featured").glob("*")) -TRAIN_SIZE = 0.9 +TRAIN_SIZE = 0.95 -CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache3") +CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache") MODELS_PATH = Path("/home/andras/projects/bipolaroid/saved_models") LOGS_PATH = Path("/home/andras/projects/bipolaroid/logs") RUNS_PATH = Path("/home/andras/projects/bipolaroid/runs") diff --git a/src/train.ipynb b/src/train.ipynb index c424838..d2acf1f 100644 --- a/src/train.ipynb +++ b/src/train.ipynb @@ -7,6 +7,13 @@ "metadata": {} }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-09-03 22:20:43,878\tINFO worker.py:1774 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n" + ] + }, { "data": { "text/plain": [ @@ -20,21 +27,35 @@ ], "source": [ "import os\n", - "from utils import set_up_logging, get_device\n", - "from training import train_with_ray_factory\n", + "import matplotlib\n", + "\n", + "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = (\n", + " \"expandable_segments:True\" # avoid fragmented CUDA memory\n", + ")\n", + "\n", + "matplotlib.use(\n", + " \"agg\"\n", + ") # avoid \"UserWarning: Starting a Matplotlib GUI outside of the main thread will likely fail\" warnings\n", + "\n", "from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n", - "from ray import tune\n", - "from ray.tune.schedulers import ASHAScheduler\n", - "import os\n", - "from ray.air import RunConfig\n", + "from utils import set_up_logging\n", "\n", "set_up_logging(LOGS_PATH)\n", "\n", + "from utils import get_device\n", + "from training import train_with_ray_factory\n", + "from ray import tune\n", + "import ray\n", + "from ray.tune.schedulers import ASHAScheduler\n", + "from ray.air import RunConfig\n", + "\n", + "\n", "TRIAL_COUNT = 100\n", "CHUNK_COUNT = 40\n", "EPOCH_COUNT = 2\n", "\n", - "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n", + "ray.init(include_dashboard=True, dashboard_host=\"0.0.0.0\")\n", + "\n", "device = get_device()\n", "f\"Using device {device}\"" ] @@ -173,7 +194,6 @@ " \"use_elu\": tune.choice([True, False]),\n", " \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n", "}\n", - "scheduler = ASHAScheduler(max_t=config[\"chunk_count\"], grace_period=2)\n", "\n", "tuner = tune.Tuner(\n", " tune.with_resources(\n", @@ -191,7 +211,9 @@ " tune_config=tune.TuneConfig(\n", " metric=\"chunk_test_loss\",\n", " mode=\"min\",\n", - " scheduler=scheduler,\n", + " scheduler=ASHAScheduler(\n", + " max_t=config[\"chunk_count\"], grace_period=2, time_attr=\"chunk_id\"\n", + " ),\n", " num_samples=TRIAL_COUNT,\n", " ),\n", " param_space=config,\n", diff --git a/src/training/train_with_ray.py b/src/training/train_with_ray.py index 5ddcc1d..d7a3f46 100644 --- a/src/training/train_with_ray.py +++ b/src/training/train_with_ray.py @@ -10,10 +10,9 @@ import torch from .get_data_loader import get_data_loader from ray import train from ray.train import Checkpoint -import os import tempfile -from more_itertools import distribute - +import logging +from more_itertools import divide EPSILON = 1e-5 @@ -24,18 +23,13 @@ def train_with_ray_factory( device: torch.device, log_dir: Path, ): + def train_with_ray(hyperparameters: Dict[str, Any]): def inner( hyperparameters: Dict[str, Any], chunk_count: int, **_, ) -> torch.nn.Module: - train_data_loader = get_data_loader(train_data_paths, **hyperparameters) - test_data_loader = get_data_loader( - test_data_paths, **{**hyperparameters, "edit_count": 1} - ) - examples = next(iter(test_data_loader)) - model, optimizer, scheduler, start_chunk_id, run_name = ( load_or_create_state( device=device, @@ -45,12 +39,25 @@ def train_with_ray_factory( ) loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device) + train_data_loaders = [ + get_data_loader(paths, **hyperparameters) + for paths in list(divide(chunk_count, train_data_paths))[ + start_chunk_id:-1 + ] + ] + test_data_loader = get_data_loader( + test_data_paths, **{**hyperparameters, "edit_count": 1} + ) + examples = next(iter(test_data_loader)) + with SummaryWriter(log_dir=log_dir / run_name) as writer: writer.add_graph(model, examples[0].to(device)) - for chunk_id, chunk in enumerate( - distribute(chunk_count, train_data_loader)[start_chunk_id:-1], + logging.info(f"Starting training with {run_name}") + for chunk_id, train_data_loader in enumerate( + train_data_loaders, start=start_chunk_id, ): + logging.info(f"Starting chunk {chunk_id}") chunk_training_loss = 0 writer.add_scalar( "Actual learning rate", @@ -58,12 +65,9 @@ def train_with_ray_factory( chunk_id, ) for batch_id, (edited_histogram, original_histogram) in enumerate( - chunk + train_data_loader ): - global_step = ( - chunk_id * (len(train_data_loader) // chunk_count) - + batch_id - ) + global_step = chunk_id * len(train_data_loader) + batch_id optimizer.zero_grad() predicted_original = model(edited_histogram.to(device)) loss = loss_function( @@ -127,6 +131,7 @@ def train_with_ray_factory( { "chunk_test_loss": chunk_test_loss, "chunk_training_loss": chunk_training_loss, + "chunk_id": chunk_id, }, checkpoint=Checkpoint.from_directory(temp_checkpoint_dir), )