-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
32 lines (22 loc) · 872 Bytes
/
utils.py
File metadata and controls
32 lines (22 loc) · 872 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import json
from pathlib import Path
from typing import Dict, Tuple
import torch
from torch.nn import Module as NNModule
from exceptions import ImageTrainerError
def get_last_child_module(model: NNModule) -> Tuple[str, NNModule]:
    """Return the final direct child of ``model`` as a ``(name, module)`` pair.

    Only immediate children (as yielded by ``named_children``) are considered;
    nested submodules are not descended into.

    Raises:
        IndexError: if ``model`` has no child modules.
    """
    children = [*model.named_children()]
    return children[-1]
def get_device(gpu: bool = False) -> torch.device:
    """Return the torch device to run on.

    Args:
        gpu: when True, require a CUDA device; when False, use the CPU.

    Returns:
        ``torch.device("cuda")`` if ``gpu`` is set and CUDA is available,
        otherwise ``torch.device("cpu")`` when ``gpu`` is not set.

    Raises:
        ImageTrainerError: if ``gpu`` is set but no CUDA device is available.
    """
    if not gpu:
        # Previously this branch built the device but fell through and
        # returned None; callers expect a usable torch.device.
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    raise ImageTrainerError("No GPU available. Run without -g/--gpu argument")
def parse_json_file(cat_name_path: Path) -> Dict:
    """Read the file at ``cat_name_path`` and parse its contents as JSON.

    The file is decoded as UTF-8 explicitly. JSON is defined as UTF-8, and
    relying on the platform's locale encoding would mis-decode or fail on
    non-ASCII content on some systems.

    NOTE(review): annotated as returning a dict; that assumes the file's
    top-level JSON value is an object. json.loads returns whatever type the
    document contains.

    Raises:
        OSError: if the file cannot be read.
        json.JSONDecodeError: if the contents are not valid JSON.
    """
    return json.loads(cat_name_path.read_text(encoding="utf-8"))
def invert_dict(the_dict: Dict) -> Dict:
    """Return a new dict mapping each value of ``the_dict`` back to its key.

    If several keys share a value, the key seen last wins. The input dict is
    not modified. Values must be hashable.
    """
    return dict(zip(the_dict.values(), the_dict.keys()))