You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
1.2 KiB
40 lines
1.2 KiB
3 years ago
|
"""
|
||
|
Load saved run args from a run directory and build data loaders from HDF5 files
|
||
|
"""
|
||
|
import torch
|
||
|
from torch.utils.data import DataLoader, TensorDataset
|
||
|
from argparse import Namespace
|
||
|
import h5py
|
||
|
import json
|
||
|
|
||
|
|
||
|
def load_args(run_dir):
    """Read the arguments saved for a run and return them as a Namespace.

    Args:
        run_dir: Directory path (string) that contains an ``args.txt`` file
            holding a JSON object.

    Returns:
        An ``argparse.Namespace`` with one attribute per key of the JSON
        object stored in ``<run_dir>/args.txt``.
    """
    args_path = run_dir + '/args.txt'
    with open(args_path) as handle:
        saved = json.load(handle)
    # Expand the JSON object into attributes: saved['lr'] becomes args.lr.
    return Namespace(**saved)
|
||
|
|
||
|
|
||
|
def load_data(hdf5_file, ndata, batch_size, only_input=True, return_stats=False):
    """Build a shuffled ``DataLoader`` from the datasets of an HDF5 file.

    Args:
        hdf5_file: Path of the HDF5 file. It must contain an ``'input'``
            dataset and, when ``only_input`` is false, an ``'output'``
            dataset.
        ndata: Number of leading samples (along axis 0) to read.
        batch_size: Batch size of the returned loader.
        only_input: If true, the loader yields 1-tuples ``(x,)``; otherwise
            it yields ``(x, y)`` pairs.
        return_stats: If true, ``stats['y_variation']`` is filled in. It is
            computed from the outputs, so it requires ``only_input=False``.

    Returns:
        A ``(data_loader, stats)`` tuple. ``stats`` is a dict that is empty
        unless ``return_stats`` is true.

    Raises:
        ValueError: If ``return_stats`` is requested together with
            ``only_input``; there would be no output data to compute the
            statistics from.
    """
    # Fail early with a clear message: previously this combination only
    # surfaced as an UnboundLocalError on y_data after the file was read.
    if return_stats and only_input:
        raise ValueError(
            'return_stats=True requires only_input=False: '
            'y_variation is computed from the output data')

    y_data = None
    with h5py.File(hdf5_file, 'r') as f:
        # NOTE(review): the inputs are cropped to the first 20 and 40 entries
        # of the last two axes (the uncropped read would be
        # f['input'][:ndata]), while the outputs below are read uncropped --
        # confirm this is intended.
        x_data = f['input'][:ndata, :, :20, :40]
        print(f'x_data: {x_data.shape}')
        if not only_input:
            y_data = f['output'][:ndata]
            print(f'y_data: {y_data.shape}')

    stats = {}
    if return_stats:
        # Squared deviation from the mean over axis 0, summed over
        # axes 0, 2 and 3, which leaves one value per index of axis 1.
        centered = y_data - y_data.mean(0, keepdims=True)
        stats['y_variation'] = (centered ** 2).sum(axis=(0, 2, 3))

    if only_input:
        data_tuple = (torch.FloatTensor(x_data),)
    else:
        data_tuple = (torch.FloatTensor(x_data), torch.FloatTensor(y_data))

    # drop_last=True: a trailing batch smaller than batch_size is discarded.
    data_loader = DataLoader(TensorDataset(*data_tuple),
                             batch_size=batch_size, shuffle=True,
                             drop_last=True)
    print(f'Loaded dataset: {hdf5_file}')
    return data_loader, stats
|
||
|
|