You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
36 lines
903 B
36 lines
903 B
3 years ago
|
import torch
|
||
|
import numpy as np
|
||
|
import os
|
||
|
|
||
|
|
||
|
def mkdir(path):
    """Create directory *path*, including any missing parent directories.

    Does nothing if *path* already exists (whether it is a directory or
    some other filesystem entry).

    Args:
        path: filesystem path of the directory to create.
    """
    if not os.path.exists(path):
        # exist_ok=True closes the race between the existence check and the
        # creation: if another process creates the same directory in that
        # window, makedirs no longer raises FileExistsError.
        os.makedirs(path, exist_ok=True)
|
||
|
|
||
|
def mkdirs(*paths):
    """Create every directory given, delegating each one to ``mkdir``.

    Each positional argument may be a single path or a list/tuple of paths,
    so both ``mkdirs(a, b)`` and ``mkdirs([a, b])`` are accepted. Calling it
    with no arguments does nothing.

    Args:
        *paths: paths, or lists/tuples of paths, of directories to create.
    """
    # ``*paths`` always arrives as a tuple, so the only thing to decide is
    # whether each element is itself a collection of paths.
    for entry in paths:
        if isinstance(entry, (list, tuple)):
            for path in entry:
                mkdir(path)
        else:
            mkdir(entry)
|
||
|
|
||
|
|
||
|
def to_numpy(input):
    """Return *input* as a NumPy array.

    NumPy arrays are returned unchanged (same object). Torch tensors are
    detached from the autograd graph, copied to CPU, and converted.

    Args:
        input: a ``torch.Tensor`` or ``np.ndarray``.

    Returns:
        np.ndarray: the converted (or original) array.

    Raises:
        TypeError: if *input* is neither a tensor nor an ndarray.
    """
    if isinstance(input, np.ndarray):
        return input
    if not isinstance(input, torch.Tensor):
        raise TypeError(
            f"Unknown type of input, expected torch.Tensor or "
            f"np.ndarray, but got {type(input)}"
        )
    # detach() drops autograd tracking; cpu() is required before numpy().
    return input.detach().cpu().numpy()
|
||
|
|
||
|
|
||
|
def module_size(module):
    """Count the parameters and (name-detected) convolution layers of a module.

    Args:
        module: the ``torch.nn.Module`` to inspect.

    Returns:
        tuple: ``(n_params, n_conv_layers)``. ``n_params`` is the total number
        of scalar elements over ``module.named_parameters()``.
        ``n_conv_layers`` is the number of distinct parameter-owning modules
        whose qualified parameter name contains the substring ``'conv'``.
        This is a naming heuristic, not a type check.

    Raises:
        AssertionError: if *module* is not a ``torch.nn.Module``.
    """
    assert isinstance(module, torch.nn.Module)
    n_params = 0
    conv_owners = set()
    for name, param in module.named_parameters():
        n_params += param.numel()
        if 'conv' in name:
            # One layer usually owns several parameters (e.g. ``conv1.weight``
            # and ``conv1.bias``). Key on the owning module's path so each
            # layer is counted once instead of once per parameter.
            conv_owners.add(name.rpartition('.')[0])
    return n_params, len(conv_owners)
|