Add get_device

This commit is contained in:
Andras Schmelczer 2024-06-30 16:15:34 +01:00
parent d2ce23fa33
commit 10249e6d26
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 8 additions and 0 deletions

View file

@ -4,3 +4,4 @@ from .generate_rotation_matrices import generate_rotation_matrices
from .kldiv import kldiv
from .set_up_logging import set_up_logging
from .serialise_hparams import serialise_hparams
from .get_device import get_device

7
src/utils/get_device.py Normal file
View file

@ -0,0 +1,7 @@
import torch
import os
def get_device() -> torch.device:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")