You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

40 lines
1.2 KiB

3 years ago
"""
Load run arguments from a run directory and datasets from an HDF5 file
"""
import json
import os
from argparse import Namespace

import h5py
import torch
from torch.utils.data import DataLoader, TensorDataset
def load_args(run_dir):
    """Load the arguments saved for a run as an ``argparse.Namespace``.

    Reads ``args.txt`` inside ``run_dir``; the file must hold a single JSON
    object, whose keys become the attributes of the returned namespace.

    Args:
        run_dir: Directory containing ``args.txt``. May be a ``str`` or any
            ``os.PathLike`` object (e.g. ``pathlib.Path``).

    Returns:
        Namespace with one attribute per top-level key of the JSON object.

    Raises:
        FileNotFoundError: If ``args.txt`` does not exist in ``run_dir``.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # os.path.join accepts PathLike objects, unlike `run_dir + '/args.txt'`.
    args_path = os.path.join(run_dir, 'args.txt')
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(args_path, encoding='utf-8') as args_file:
        return Namespace(**json.load(args_file))
def load_data(hdf5_file, ndata, batch_size, only_input=True, return_stats=False):
    """Build a shuffled ``DataLoader`` from the leading samples of an HDF5 file.

    Args:
        hdf5_file: Path of the HDF5 file. It must contain an ``'input'``
            dataset, and an ``'output'`` dataset whenever outputs or
            statistics are requested.
        ndata: Number of leading samples to read from each dataset.
        batch_size: Batch size passed to the ``DataLoader``.
        only_input: If true, batches are ``(x,)``; otherwise ``(x, y)``.
        return_stats: If true, ``stats['y_variation']`` is set to the squared
            deviation of the outputs from their mean over samples, summed
            over axes 0, 2 and 3.

    Returns:
        Tuple ``(data_loader, stats)``. ``stats`` is an empty dict unless
        ``return_stats`` is true.

    Note:
        The loader uses ``drop_last=True``, so a trailing partial batch is
        discarded (and no batch is produced if fewer than ``batch_size``
        samples are read).
    """
    # The outputs are needed either as part of the returned batches or to
    # compute statistics. Previously `return_stats=True` together with
    # `only_input=True` crashed with UnboundLocalError on `y_data`.
    need_output = (not only_input) or return_stats

    with h5py.File(hdf5_file, 'r') as f:
        # Inputs are cropped to the leading 20 x 40 window of axes 2 and 3;
        # outputs are read uncropped.
        # NOTE(review): crop size is hard-coded - confirm it matches the model.
        x_data = f['input'][:ndata, :, :20, :40]
        print(f'x_data: {x_data.shape}')
        y_data = None
        if need_output:
            y_data = f['output'][:ndata]
            print(f'y_data: {y_data.shape}')

    stats = {}
    if return_stats:
        # Deviation from the per-location mean over samples (axis 0).
        centered = y_data - y_data.mean(0, keepdims=True)
        stats['y_variation'] = (centered ** 2).sum(axis=(0, 2, 3))

    if only_input:
        tensors = (torch.FloatTensor(x_data),)
    else:
        tensors = (torch.FloatTensor(x_data), torch.FloatTensor(y_data))

    data_loader = DataLoader(
        TensorDataset(*tensors),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    print(f'Loaded dataset: {hdf5_file}')
    return data_loader, stats