diff --git a/src/utils/__init__.py b/src/utils/__init__.py index de54983..54ec8fc 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -4,3 +4,4 @@ from .generate_rotation_matrices import generate_rotation_matrices from .kldiv import kldiv from .set_up_logging import set_up_logging from .serialise_hparams import serialise_hparams +from .get_device import get_device diff --git a/src/utils/get_device.py b/src/utils/get_device.py new file mode 100644 index 0000000..029b5ee --- /dev/null +++ b/src/utils/get_device.py @@ -0,0 +1,7 @@ +import torch +import os + + +def get_device() -> torch.device: + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")