Add get_device
This commit is contained in:
parent
d2ce23fa33
commit
10249e6d26
2 changed files with 8 additions and 0 deletions
|
|
@ -4,3 +4,4 @@ from .generate_rotation_matrices import generate_rotation_matrices
|
|||
from .kldiv import kldiv
|
||||
from .set_up_logging import set_up_logging
|
||||
from .serialise_hparams import serialise_hparams
|
||||
from .get_device import get_device
|
||||
|
|
|
|||
7
src/utils/get_device.py
Normal file
7
src/utils/get_device.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
import torch
|
||||
import os
|
||||
|
||||
|
||||
def get_device() -> torch.device:
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||
Loading…
Add table
Add a link
Reference in a new issue