Fix training loop OOMs
This commit is contained in:
parent
fb07fc674e
commit
edd92593b7
3 changed files with 54 additions and 27 deletions
|
|
@ -4,9 +4,9 @@ from pathlib import Path
|
||||||
# DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/downloaded-unsplash").glob("*"))
|
# DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/downloaded-unsplash").glob("*"))
|
||||||
DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/featured").glob("*"))
|
DATA = sorted(Path("/mnt/wsl/PHYSICALDRIVE0p1/featured").glob("*"))
|
||||||
|
|
||||||
TRAIN_SIZE = 0.9
|
TRAIN_SIZE = 0.95
|
||||||
|
|
||||||
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache3")
|
CACHE_PATH = Path("/mnt/wsl/PHYSICALDRIVE2/data/cache")
|
||||||
MODELS_PATH = Path("/home/andras/projects/bipolaroid/saved_models")
|
MODELS_PATH = Path("/home/andras/projects/bipolaroid/saved_models")
|
||||||
LOGS_PATH = Path("/home/andras/projects/bipolaroid/logs")
|
LOGS_PATH = Path("/home/andras/projects/bipolaroid/logs")
|
||||||
RUNS_PATH = Path("/home/andras/projects/bipolaroid/runs")
|
RUNS_PATH = Path("/home/andras/projects/bipolaroid/runs")
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,13 @@
|
||||||
"metadata": {}
|
"metadata": {}
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"2024-09-03 22:20:43,878\tINFO worker.py:1774 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
|
|
@ -20,21 +27,35 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"from utils import set_up_logging, get_device\n",
|
"import matplotlib\n",
|
||||||
"from training import train_with_ray_factory\n",
|
"\n",
|
||||||
|
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = (\n",
|
||||||
|
" \"expandable_segments:True\" # avoid fragmented CUDA memory\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"matplotlib.use(\n",
|
||||||
|
" \"agg\"\n",
|
||||||
|
") # avoid \"UserWarning: Starting a Matplotlib GUI outside of the main thread will likely fail\" warnings\n",
|
||||||
|
"\n",
|
||||||
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
|
"from config import LOGS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA\n",
|
||||||
"from ray import tune\n",
|
"from utils import set_up_logging\n",
|
||||||
"from ray.tune.schedulers import ASHAScheduler\n",
|
|
||||||
"import os\n",
|
|
||||||
"from ray.air import RunConfig\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"set_up_logging(LOGS_PATH)\n",
|
"set_up_logging(LOGS_PATH)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"from utils import get_device\n",
|
||||||
|
"from training import train_with_ray_factory\n",
|
||||||
|
"from ray import tune\n",
|
||||||
|
"import ray\n",
|
||||||
|
"from ray.tune.schedulers import ASHAScheduler\n",
|
||||||
|
"from ray.air import RunConfig\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
"TRIAL_COUNT = 100\n",
|
"TRIAL_COUNT = 100\n",
|
||||||
"CHUNK_COUNT = 40\n",
|
"CHUNK_COUNT = 40\n",
|
||||||
"EPOCH_COUNT = 2\n",
|
"EPOCH_COUNT = 2\n",
|
||||||
"\n",
|
"\n",
|
||||||
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
|
"ray.init(include_dashboard=True, dashboard_host=\"0.0.0.0\")\n",
|
||||||
|
"\n",
|
||||||
"device = get_device()\n",
|
"device = get_device()\n",
|
||||||
"f\"Using device {device}\""
|
"f\"Using device {device}\""
|
||||||
]
|
]
|
||||||
|
|
@ -173,7 +194,6 @@
|
||||||
" \"use_elu\": tune.choice([True, False]),\n",
|
" \"use_elu\": tune.choice([True, False]),\n",
|
||||||
" \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n",
|
" \"leaky_relu_alpha\": tune.uniform(0, 0.05),\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"scheduler = ASHAScheduler(max_t=config[\"chunk_count\"], grace_period=2)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"tuner = tune.Tuner(\n",
|
"tuner = tune.Tuner(\n",
|
||||||
" tune.with_resources(\n",
|
" tune.with_resources(\n",
|
||||||
|
|
@ -191,7 +211,9 @@
|
||||||
" tune_config=tune.TuneConfig(\n",
|
" tune_config=tune.TuneConfig(\n",
|
||||||
" metric=\"chunk_test_loss\",\n",
|
" metric=\"chunk_test_loss\",\n",
|
||||||
" mode=\"min\",\n",
|
" mode=\"min\",\n",
|
||||||
" scheduler=scheduler,\n",
|
" scheduler=ASHAScheduler(\n",
|
||||||
|
" max_t=config[\"chunk_count\"], grace_period=2, time_attr=\"chunk_id\"\n",
|
||||||
|
" ),\n",
|
||||||
" num_samples=TRIAL_COUNT,\n",
|
" num_samples=TRIAL_COUNT,\n",
|
||||||
" ),\n",
|
" ),\n",
|
||||||
" param_space=config,\n",
|
" param_space=config,\n",
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,9 @@ import torch
|
||||||
from .get_data_loader import get_data_loader
|
from .get_data_loader import get_data_loader
|
||||||
from ray import train
|
from ray import train
|
||||||
from ray.train import Checkpoint
|
from ray.train import Checkpoint
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from more_itertools import distribute
|
import logging
|
||||||
|
from more_itertools import divide
|
||||||
|
|
||||||
EPSILON = 1e-5
|
EPSILON = 1e-5
|
||||||
|
|
||||||
|
|
@ -24,18 +23,13 @@ def train_with_ray_factory(
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
log_dir: Path,
|
log_dir: Path,
|
||||||
):
|
):
|
||||||
|
|
||||||
def train_with_ray(hyperparameters: Dict[str, Any]):
|
def train_with_ray(hyperparameters: Dict[str, Any]):
|
||||||
def inner(
|
def inner(
|
||||||
hyperparameters: Dict[str, Any],
|
hyperparameters: Dict[str, Any],
|
||||||
chunk_count: int,
|
chunk_count: int,
|
||||||
**_,
|
**_,
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
train_data_loader = get_data_loader(train_data_paths, **hyperparameters)
|
|
||||||
test_data_loader = get_data_loader(
|
|
||||||
test_data_paths, **{**hyperparameters, "edit_count": 1}
|
|
||||||
)
|
|
||||||
examples = next(iter(test_data_loader))
|
|
||||||
|
|
||||||
model, optimizer, scheduler, start_chunk_id, run_name = (
|
model, optimizer, scheduler, start_chunk_id, run_name = (
|
||||||
load_or_create_state(
|
load_or_create_state(
|
||||||
device=device,
|
device=device,
|
||||||
|
|
@ -45,12 +39,25 @@ def train_with_ray_factory(
|
||||||
)
|
)
|
||||||
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
|
loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)
|
||||||
|
|
||||||
|
train_data_loaders = [
|
||||||
|
get_data_loader(paths, **hyperparameters)
|
||||||
|
for paths in list(divide(chunk_count, train_data_paths))[
|
||||||
|
start_chunk_id:-1
|
||||||
|
]
|
||||||
|
]
|
||||||
|
test_data_loader = get_data_loader(
|
||||||
|
test_data_paths, **{**hyperparameters, "edit_count": 1}
|
||||||
|
)
|
||||||
|
examples = next(iter(test_data_loader))
|
||||||
|
|
||||||
with SummaryWriter(log_dir=log_dir / run_name) as writer:
|
with SummaryWriter(log_dir=log_dir / run_name) as writer:
|
||||||
writer.add_graph(model, examples[0].to(device))
|
writer.add_graph(model, examples[0].to(device))
|
||||||
for chunk_id, chunk in enumerate(
|
logging.info(f"Starting training with {run_name}")
|
||||||
distribute(chunk_count, train_data_loader)[start_chunk_id:-1],
|
for chunk_id, train_data_loader in enumerate(
|
||||||
|
train_data_loaders,
|
||||||
start=start_chunk_id,
|
start=start_chunk_id,
|
||||||
):
|
):
|
||||||
|
logging.info(f"Starting chunk {chunk_id}")
|
||||||
chunk_training_loss = 0
|
chunk_training_loss = 0
|
||||||
writer.add_scalar(
|
writer.add_scalar(
|
||||||
"Actual learning rate",
|
"Actual learning rate",
|
||||||
|
|
@ -58,12 +65,9 @@ def train_with_ray_factory(
|
||||||
chunk_id,
|
chunk_id,
|
||||||
)
|
)
|
||||||
for batch_id, (edited_histogram, original_histogram) in enumerate(
|
for batch_id, (edited_histogram, original_histogram) in enumerate(
|
||||||
chunk
|
train_data_loader
|
||||||
):
|
):
|
||||||
global_step = (
|
global_step = chunk_id * len(train_data_loader) + batch_id
|
||||||
chunk_id * (len(train_data_loader) // chunk_count)
|
|
||||||
+ batch_id
|
|
||||||
)
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
predicted_original = model(edited_histogram.to(device))
|
predicted_original = model(edited_histogram.to(device))
|
||||||
loss = loss_function(
|
loss = loss_function(
|
||||||
|
|
@ -127,6 +131,7 @@ def train_with_ray_factory(
|
||||||
{
|
{
|
||||||
"chunk_test_loss": chunk_test_loss,
|
"chunk_test_loss": chunk_test_loss,
|
||||||
"chunk_training_loss": chunk_training_loss,
|
"chunk_training_loss": chunk_training_loss,
|
||||||
|
"chunk_id": chunk_id,
|
||||||
},
|
},
|
||||||
checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
|
checkpoint=Checkpoint.from_directory(temp_checkpoint_dir),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue