Fix training loop OOMs
This commit is contained in:
parent
fb07fc674e
commit
edd92593b7
3 changed files with 54 additions and 27 deletions
|
|
@ -4,9 +4,9 @@ from pathlib import Path
|
|||
# DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/downloaded-unsplash").glob("*"))
|
||||
DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/featured").glob("*"))
|
||||
|
||||
TRAIN_SIZE = 0.9
|
||||
TRAIN_SIZE = 0.95
|
||||
|
||||
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache3")
|
||||
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache")
|
||||
MODELS_PATH = Path("/home/andras/projects/bipolaroid/saved_models")
|
||||
LOGS_PATH = Path("/home/andras/projects/bipolaroid/logs")
|
||||
RUNS_PATH = Path("/home/andras/projects/bipolaroid/runs")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,13 @@
|
|||
"metadata": {}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-09-03 22:20:43,878\tINFO worker.py:1774 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
|
|
@ -20,21 +27,35 @@
|
|||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from utils import set_up_logging, get_device\n",
|
||||
"from training import train_with_ray_factory\n",
|
||||
"import matplotlib\n",
|
||||
"\n",
|
||||
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = (\n",
|
||||
" \"expandable_segments:True\" # avoid fragmented CUDA memory\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"matplotlib.use(\n",
|
||||
" \"agg\"\n",
|
||||
") # avoid \"UserWarning: Starting a Matplotlib GUI outside of the main thread will likely fail\" warnings\n",
|
||||
"\n",
|
||||
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
|
||||
"from ray import tune\n",
|
||||
"from ray.tune.schedulers import ASHAScheduler\n",
|
||||
"import os\n",
|
||||
"from ray.air import RunConfig\n",
|
||||
"from utils import set_up_logging\n",
|
||||
"\n",
|
||||
"set_up_logging(LOGS_PATH)\n",
|
||||
"\n",
|
||||
"from utils import get_device\n",
|
||||
"from training import train_with_ray_factory\n",
|
||||
"from ray import tune\n",
|
||||
"import ray\n",
|
||||
"from ray.tune.schedulers import ASHAScheduler\n",
|
||||
"from ray.air import RunConfig\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"TRIAL_COUNT = 100\n",
|
||||
"CHUNK_COUNT = 40\n",
|
||||
"EPOCH_COUNT = 2\n",
|
||||
"\n",
|
||||
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
|
||||
"ray.init(include_dashboard=True, dashboard_host=\"0.0.0.0\")\n",
|
||||
"\n",
|
||||
"device = get_device()\n",
|
||||
"f\"Using device {device}\""
|
||||
]
|
||||
|
|
@ -173,7 +194,6 @@
|
|||
" \"use_elu\": tune.choice([True, False]),\n",
|
||||
" \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n",
|
||||
"}\n",
|
||||
"scheduler = ASHAScheduler(max_t=config[\"chunk_count\"], grace_period=2)\n",
|
||||
"\n",
|
||||
"tuner = tune.Tuner(\n",
|
||||
" tune.with_resources(\n",
|
||||
|
|
@ -191,7 +211,9 @@
|
|||
" tune_config=tune.TuneConfig(\n",
|
||||
" metric=\"chunk_test_loss\",\n",
|
||||
" mode=\"min\",\n",
|
||||
" scheduler=scheduler,\n",
|
||||
" scheduler=ASHAScheduler(\n",
|
||||
" max_t=config[\"chunk_count\"], grace_period=2, time_attr=\"chunk_id\"\n",
|
||||
" ),\n",
|
||||
" num_samples=TRIAL_COUNT,\n",
|
||||
" ),\n",
|
||||
" param_space=config,\n",
|
||||
|
|
|
|||
|
|
@ -10,10 +10,9 @@ import torch
|
|||
from .get_data_loader import get_data_loader
|
||||
from ray import train
|
||||
from ray.train import Checkpoint
|
||||
import os
|
||||
import tempfile
|
||||
from more_itertools import distribute
|
||||
|
||||
import logging
|
||||
from more_itertools import divide
|
||||
|
||||
EPSILON = 1e-5
|
||||
|
||||
|
|
@ -24,18 +23,13 @@ def train_with_ray_factory(
|
|||
device: torch.device,
|
||||
log_dir: Path,
|
||||
):
|
||||
|
||||
def train_with_ray(hyperparameters: Dict[str, Any]):
|
||||
def inner(
|
||||
hyperparameters: Dict[str, Any],
|
||||
chunk_count: int,
|
||||
**_,
|
||||
) -> torch.nn.Module:
|
||||
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
|
||||
test_data_loader = get_data_loader(
|
||||
test_data_paths, **{**hyperparameters, "edit_count": 1}
|
||||
)
|
||||
examples = next(iter(test_data_loader))
|
||||
|
||||
model, optimizer, scheduler, start_chunk_id, run_name = (
|
||||
load_or_create_state(
|
||||
device=device,
|
||||
|
|
@ -45,12 +39,25 @@ def train_with_ray_factory(
|
|||
)
|
||||
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
|
||||
|
||||
train_data_loaders = [
|
||||
get_data_loader(paths, **hyperparameters)
|
||||
for paths in list(divide(chunk_count, train_data_paths))[
|
||||
start_chunk_id:-1
|
||||
]
|
||||
]
|
||||
test_data_loader = get_data_loader(
|
||||
test_data_paths, **{**hyperparameters, "edit_count": 1}
|
||||
)
|
||||
examples = next(iter(test_data_loader))
|
||||
|
||||
with SummaryWriter(log_dir=log_dir / run_name) as writer:
|
||||
writer.add_graph(model, examples[0].to(device))
|
||||
for chunk_id, chunk in enumerate(
|
||||
distribute(chunk_count, train_data_loader)[start_chunk_id:-1],
|
||||
logging.info(f"Starting training with {run_name}")
|
||||
for chunk_id, train_data_loader in enumerate(
|
||||
train_data_loaders,
|
||||
start=start_chunk_id,
|
||||
):
|
||||
logging.info(f"Starting chunk {chunk_id}")
|
||||
chunk_training_loss = 0
|
||||
writer.add_scalar(
|
||||
"Actual learning rate",
|
||||
|
|
@ -58,12 +65,9 @@ def train_with_ray_factory(
|
|||
chunk_id,
|
||||
)
|
||||
for batch_id, (edited_histogram, original_histogram) in enumerate(
|
||||
chunk
|
||||
train_data_loader
|
||||
):
|
||||
global_step = (
|
||||
chunk_id * (len(train_data_loader) // chunk_count)
|
||||
+ batch_id
|
||||
)
|
||||
global_step = chunk_id * len(train_data_loader) + batch_id
|
||||
optimizer.zero_grad()
|
||||
predicted_original = model(edited_histogram.to(device))
|
||||
loss = loss_function(
|
||||
|
|
@ -127,6 +131,7 @@ def train_with_ray_factory(
|
|||
{
|
||||
"chunk_test_loss": chunk_test_loss,
|
||||
"chunk_training_loss": chunk_training_loss,
|
||||
"chunk_id": chunk_id,
|
||||
},
|
||||
checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue