From 10249e6d26e5e6c6061f88c9d33c65eeef8f3b5c Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Sun, 30 Jun 2024 16:15:34 +0100 Subject: [PATCH] Add get_device --- src/utils/__init__.py | 1 + src/utils/get_device.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 src/utils/get_device.py diff --git a/src/utils/__init__.py b/src/utils/__init__.py index de54983..54ec8fc 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -4,3 +4,4 @@ from .generate_rotation_matrices import generate_rotation_matrices from .kldiv import kldiv from .set_up_logging import set_up_logging from .serialise_hparams import serialise_hparams +from .get_device import get_device diff --git a/src/utils/get_device.py b/src/utils/get_device.py new file mode 100644 index 0000000..029b5ee --- /dev/null +++ b/src/utils/get_device.py @@ -0,0 +1,7 @@ +import torch +import os + + +def get_device() -> torch.device: + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")