You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

36 lines
903 B

import torch
import numpy as np
import os
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def mkdirs(*paths):
# print(paths)
if isinstance(paths, list) or isinstance(paths, tuple):
for path in paths:
mkdir(path)
else:
raise ValueError
def to_numpy(input):
if isinstance(input, torch.Tensor):
return input.detach().cpu().numpy()
elif isinstance(input, np.ndarray):
return input
else:
raise TypeError('Unknown type of input, expected torch.Tensor or '\
'np.ndarray, but got {}'.format(type(input)))
def module_size(module):
assert isinstance(module, torch.nn.Module)
n_params, n_conv_layers = 0, 0
for name, param in module.named_parameters():
if 'conv' in name:
n_conv_layers += 1
n_params += param.numel()
return n_params, n_conv_layers