Fix training loop OOMs

This commit is contained in:
Andras Schmelczer 2024-09-03 22:34:08 +01:00
parent fb07fc674e
commit edd92593b7
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
3 changed files with 54 additions and 27 deletions

View file

@ -4,9 +4,9 @@ from pathlib import Path
# DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/downloaded-unsplash").glob("*")) # DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/downloaded-unsplash").glob("*"))
DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/featured").glob("*")) DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/featured").glob("*"))
TRAIN_SIZE = 0.9 TRAIN_SIZE = 0.95
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache3") CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache")
MODELS_PATH = Path("/home/andras/projects/bipolaroid/saved_models") MODELS_PATH = Path("/home/andras/projects/bipolaroid/saved_models")
LOGS_PATH = Path("/home/andras/projects/bipolaroid/logs") LOGS_PATH = Path("/home/andras/projects/bipolaroid/logs")
RUNS_PATH = Path("/home/andras/projects/bipolaroid/runs") RUNS_PATH = Path("/home/andras/projects/bipolaroid/runs")

View file

@ -7,6 +7,13 @@
"metadata": {} "metadata": {}
}, },
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-09-03 22:20:43,878\tINFO worker.py:1774 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n"
]
},
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
@ -20,21 +27,35 @@
], ],
"source": [ "source": [
"import os\n", "import os\n",
"from utils import set_up_logging, get_device\n", "import matplotlib\n",
"from training import train_with_ray_factory\n", "\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = (\n",
" \"expandable_segments:True\" # avoid fragmented CUDA memory\n",
")\n",
"\n",
"matplotlib.use(\n",
" \"agg\"\n",
") # avoid \"UserWarning: Starting a Matplotlib GUI outside of the main thread will likely fail\" warnings\n",
"\n",
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n", "from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
"from ray import tune\n", "from utils import set_up_logging\n",
"from ray.tune.schedulers import ASHAScheduler\n",
"import os\n",
"from ray.air import RunConfig\n",
"\n", "\n",
"set_up_logging(LOGS_PATH)\n", "set_up_logging(LOGS_PATH)\n",
"\n", "\n",
"from utils import get_device\n",
"from training import train_with_ray_factory\n",
"from ray import tune\n",
"import ray\n",
"from ray.tune.schedulers import ASHAScheduler\n",
"from ray.air import RunConfig\n",
"\n",
"\n",
"TRIAL_COUNT = 100\n", "TRIAL_COUNT = 100\n",
"CHUNK_COUNT = 40\n", "CHUNK_COUNT = 40\n",
"EPOCH_COUNT = 2\n", "EPOCH_COUNT = 2\n",
"\n", "\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n", "ray.init(include_dashboard=True, dashboard_host=\"0.0.0.0\")\n",
"\n",
"device = get_device()\n", "device = get_device()\n",
"f\"Using device {device}\"" "f\"Using device {device}\""
] ]
@ -173,7 +194,6 @@
" \"use_elu\": tune.choice([True, False]),\n", " \"use_elu\": tune.choice([True, False]),\n",
" \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n", " \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n",
"}\n", "}\n",
"scheduler = ASHAScheduler(max_t=config[\"chunk_count\"], grace_period=2)\n",
"\n", "\n",
"tuner = tune.Tuner(\n", "tuner = tune.Tuner(\n",
" tune.with_resources(\n", " tune.with_resources(\n",
@ -191,7 +211,9 @@
" tune_config=tune.TuneConfig(\n", " tune_config=tune.TuneConfig(\n",
" metric=\"chunk_test_loss\",\n", " metric=\"chunk_test_loss\",\n",
" mode=\"min\",\n", " mode=\"min\",\n",
" scheduler=scheduler,\n", " scheduler=ASHAScheduler(\n",
" max_t=config[\"chunk_count\"], grace_period=2, time_attr=\"chunk_id\"\n",
" ),\n",
" num_samples=TRIAL_COUNT,\n", " num_samples=TRIAL_COUNT,\n",
" ),\n", " ),\n",
" param_space=config,\n", " param_space=config,\n",

View file

@ -10,10 +10,9 @@ import torch
from .get_data_loader import get_data_loader from .get_data_loader import get_data_loader
from ray import train from ray import train
from ray.train import Checkpoint from ray.train import Checkpoint
import os
import tempfile import tempfile
from more_itertools import distribute import logging
from more_itertools import divide
EPSILON = 1e-5 EPSILON = 1e-5
@ -24,18 +23,13 @@ def train_with_ray_factory(
device: torch.device, device: torch.device,
log_dir: Path, log_dir: Path,
): ):
def train_with_ray(hyperparameters: Dict[str, Any]): def train_with_ray(hyperparameters: Dict[str, Any]):
def inner( def inner(
hyperparameters: Dict[str, Any], hyperparameters: Dict[str, Any],
chunk_count: int, chunk_count: int,
**_, **_,
) -> torch.nn.Module: ) -> torch.nn.Module:
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
test_data_loader = get_data_loader(
test_data_paths, **{**hyperparameters, "edit_count": 1}
)
examples = next(iter(test_data_loader))
model, optimizer, scheduler, start_chunk_id, run_name = ( model, optimizer, scheduler, start_chunk_id, run_name = (
load_or_create_state( load_or_create_state(
device=device, device=device,
@ -45,12 +39,25 @@ def train_with_ray_factory(
) )
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device) loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
train_data_loaders = [
get_data_loader(paths, **hyperparameters)
for paths in list(divide(chunk_count, train_data_paths))[
start_chunk_id:-1
]
]
test_data_loader = get_data_loader(
test_data_paths, **{**hyperparameters, "edit_count": 1}
)
examples = next(iter(test_data_loader))
with SummaryWriter(log_dir=log_dir / run_name) as writer: with SummaryWriter(log_dir=log_dir / run_name) as writer:
writer.add_graph(model, examples[0].to(device)) writer.add_graph(model, examples[0].to(device))
for chunk_id, chunk in enumerate( logging.info(f"Starting training with {run_name}")
distribute(chunk_count, train_data_loader)[start_chunk_id:-1], for chunk_id, train_data_loader in enumerate(
train_data_loaders,
start=start_chunk_id, start=start_chunk_id,
): ):
logging.info(f"Starting chunk {chunk_id}")
chunk_training_loss = 0 chunk_training_loss = 0
writer.add_scalar( writer.add_scalar(
"Actual learning rate", "Actual learning rate",
@ -58,12 +65,9 @@ def train_with_ray_factory(
chunk_id, chunk_id,
) )
for batch_id, (edited_histogram, original_histogram) in enumerate( for batch_id, (edited_histogram, original_histogram) in enumerate(
chunk train_data_loader
): ):
global_step = ( global_step = chunk_id * len(train_data_loader) + batch_id
chunk_id * (len(train_data_loader) // chunk_count)
+ batch_id
)
optimizer.zero_grad() optimizer.zero_grad()
predicted_original = model(edited_histogram.to(device)) predicted_original = model(edited_histogram.to(device))
loss = loss_function( loss = loss_function(
@ -127,6 +131,7 @@ def train_with_ray_factory(
{ {
"chunk_test_loss": chunk_test_loss, "chunk_test_loss": chunk_test_loss,
"chunk_training_loss": chunk_training_loss, "chunk_training_loss": chunk_training_loss,
"chunk_id": chunk_id,
}, },
checkpoint=Checkpoint.from_directory(temp_checkpoint_dir), checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
) )