You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

198 lines
4.9 KiB

#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import json
import os
import torch
# Names of the sub-directories and files that the helpers in this module join
# onto an experiment directory or a data directory to build on-disk paths.

# Checkpoint locations inside an experiment directory. Files in the model and
# latent-code sub-directories are looked up as "<checkpoint>.pth" by
# load_model_parameters and load_latent_vectors.
model_params_subdir = "ModelParameters"
optimizer_params_subdir = "OptimizerParameters"
latent_codes_subdir = "LatentCodes"
# Not referenced in the visible part of this file; presumably the training
# log file of an experiment -- TODO confirm against the callers.
logs_filename = "Logs.pth"
# Reconstruction outputs are laid out as
# <experiment>/<reconstructions_subdir>/<epoch>/<meshes|codes subdir>/...
reconstructions_subdir = "Reconstructions"
reconstruction_meshes_subdir = "Meshes"
reconstruction_codes_subdir = "Codes"
# File read by load_experiment_specifications from the experiment directory.
specifications_filename = "specs.json"
# File name joined onto a data directory by get_data_source_map_filename.
data_source_map_filename = ".datasources.json"
# Per-checkpoint sub-directory root used by get_evaluation_dir.
evaluation_subdir = "Evaluation"
# Not referenced in the visible part of this file; presumably sub-directories
# of a data directory -- TODO confirm against the callers.
sdf_samples_subdir = "SdfSamples"
surface_samples_subdir = "SurfaceSamples"
# Sub-directory of a data directory used by get_normalization_params_filename.
normalization_param_subdir = "NormalizationParameters"
# Not referenced in the visible part of this file -- TODO confirm usage.
training_meshes_subdir = "TrainingMeshes"
def load_experiment_specifications(experiment_directory):
    """Load the JSON specifications of an experiment.

    Args:
        experiment_directory: Directory expected to contain the
            specifications file (``specifications_filename``).

    Returns:
        The parsed JSON content of the specifications file.

    Raises:
        Exception: If the specifications file is not present in
            ``experiment_directory``.
    """
    filename = os.path.join(experiment_directory, specifications_filename)

    if not os.path.isfile(filename):
        # Adjacent literals are joined before ``.format`` is applied, so the
        # directory placeholder is actually filled in. (With ``+`` the call
        # binds to the last literal only and the "{}" is left untouched.)
        raise Exception(
            "The experiment directory ({}) does not include specifications "
            'file "specs.json"'.format(experiment_directory)
        )

    # Use a context manager so the file handle is closed deterministically.
    with open(filename) as specs_file:
        return json.load(specs_file)
def load_model_parameters(experiment_directory, checkpoint, decoder):
    """Restore saved weights into ``decoder`` and return the saved epoch.

    The checkpoint is read from
    ``<experiment_directory>/<model_params_subdir>/<checkpoint>.pth``.

    Args:
        experiment_directory: Root directory of the experiment.
        checkpoint: Checkpoint name without the ``.pth`` extension.
        decoder: Object exposing ``load_state_dict``; it is updated in place
            with the checkpoint's ``"model_state_dict"`` entry.

    Returns:
        The value stored under the ``"epoch"`` key of the checkpoint.

    Raises:
        Exception: If the checkpoint file does not exist.
    """
    checkpoint_path = os.path.join(
        experiment_directory, model_params_subdir, checkpoint + ".pth"
    )

    if os.path.isfile(checkpoint_path):
        saved_state = torch.load(checkpoint_path)
        decoder.load_state_dict(saved_state["model_state_dict"])
        return saved_state["epoch"]

    raise Exception(
        'model state dict "{}" does not exist'.format(checkpoint_path)
    )
def build_decoder(experiment_directory, experiment_specs):
    """Instantiate the decoder network described by ``experiment_specs``.

    The module ``networks.<NetworkArch>`` is imported dynamically and its
    ``Decoder`` class is constructed with the code length and the
    ``NetworkSpecs`` entries as keyword arguments, then moved to CUDA.

    Args:
        experiment_directory: Accepted for interface symmetry; not read here.
        experiment_specs: Mapping with the keys ``"NetworkArch"``,
            ``"CodeLength"`` and ``"NetworkSpecs"``.

    Returns:
        The constructed decoder after ``.cuda()`` has been called on it.
    """
    module_name = "networks." + experiment_specs["NetworkArch"]
    # A non-empty ``fromlist`` makes ``__import__`` return the sub-module
    # itself rather than the top-level ``networks`` package.
    network_module = __import__(module_name, fromlist=["Decoder"])

    code_length = experiment_specs["CodeLength"]
    decoder_cls = network_module.Decoder
    return decoder_cls(code_length, **experiment_specs["NetworkSpecs"]).cuda()
def load_decoder(
    experiment_directory, experiment_specs, checkpoint, data_parallel=True
):
    """Build a decoder and load a checkpoint's weights into it.

    Args:
        experiment_directory: Root directory of the experiment.
        experiment_specs: Specifications passed on to ``build_decoder``.
        checkpoint: Checkpoint name passed on to ``load_model_parameters``.
        data_parallel: When true, wrap the decoder in
            ``torch.nn.DataParallel`` before the weights are loaded, so the
            state dict is loaded into the wrapper.

    Returns:
        A ``(decoder, epoch)`` tuple, where ``epoch`` is the value returned
        by ``load_model_parameters``.
    """
    network = build_decoder(experiment_directory, experiment_specs)
    network = torch.nn.DataParallel(network) if data_parallel else network

    saved_epoch = load_model_parameters(
        experiment_directory, checkpoint, network
    )
    return (network, saved_epoch)
def load_latent_vectors(experiment_directory, checkpoint):
    """Load the latent codes stored for a checkpoint.

    The file is read from
    ``<experiment_directory>/<latent_codes_subdir>/<checkpoint>.pth`` and its
    ``"latent_codes"`` entry is handled in one of two ways:

    * a ``torch.Tensor``: it is split along its first dimension;
    * otherwise: it is treated as the state dict of a ``torch.nn.Embedding``
      (it must provide a ``"weight"`` entry).

    Args:
        experiment_directory: Root directory of the experiment.
        checkpoint: Checkpoint name without the ``.pth`` extension.

    Returns:
        For the tensor form, a list with one CUDA tensor per entry along the
        first dimension. For the state-dict form, the detached embedding
        weight tensor.

    Raises:
        Exception: If the latent code file does not exist.
    """
    filename = os.path.join(
        experiment_directory, latent_codes_subdir, checkpoint + ".pth"
    )

    if not os.path.isfile(filename):
        # Adjacent literals are joined before ``.format`` is applied, so both
        # placeholders are filled. (With ``+`` the call binds to the last
        # literal only: the directory slot stays "{}" and the checkpoint slot
        # receives the directory.)
        raise Exception(
            "The experiment directory ({}) does not include a latent code file"
            " for checkpoint '{}'".format(experiment_directory, checkpoint)
        )

    latent_codes = torch.load(filename)["latent_codes"]

    if isinstance(latent_codes, torch.Tensor):
        # Iterating a tensor yields its slices along the first dimension.
        return [code.cuda() for code in latent_codes]

    # State-dict form: rebuild an embedding of matching shape and load into it.
    num_embeddings, embedding_dim = latent_codes["weight"].shape
    embedding = torch.nn.Embedding(num_embeddings, embedding_dim)
    embedding.load_state_dict(latent_codes)
    return embedding.weight.data.detach()
def get_data_source_map_filename(data_dir):
    """Return the path of the data-source map file inside ``data_dir``.

    Args:
        data_dir: Directory the file name is joined onto.

    Returns:
        ``data_dir`` joined with ``data_source_map_filename``.
    """
    path_parts = (data_dir, data_source_map_filename)
    return os.path.join(*path_parts)
def get_reconstructed_mesh_filename(
    experiment_dir, epoch, dataset, class_name, instance_name
):
    """Return the path of a reconstructed mesh (``.ply``) file.

    Layout:
    ``<experiment_dir>/<reconstructions_subdir>/<epoch>/``
    ``<reconstruction_meshes_subdir>/<dataset>/<class_name>/<instance_name>.ply``

    Args:
        experiment_dir: Root directory of the experiment.
        epoch: Epoch identifier; converted with ``str``.
        dataset: Dataset name used as a path component.
        class_name: Class name used as a path component.
        instance_name: Instance name; ``".ply"`` is appended.

    Returns:
        The joined path. Nothing is created on disk.
    """
    mesh_file = instance_name + ".ply"
    path_parts = (
        experiment_dir,
        reconstructions_subdir,
        str(epoch),
        reconstruction_meshes_subdir,
        dataset,
        class_name,
        mesh_file,
    )
    return os.path.join(*path_parts)
def get_reconstructed_code_filename(
    experiment_dir, epoch, dataset, class_name, instance_name
):
    """Return the path of a reconstructed latent code (``.pth``) file.

    Layout:
    ``<experiment_dir>/<reconstructions_subdir>/<epoch>/``
    ``<reconstruction_codes_subdir>/<dataset>/<class_name>/<instance_name>.pth``

    Args:
        experiment_dir: Root directory of the experiment.
        epoch: Epoch identifier; converted with ``str``.
        dataset: Dataset name used as a path component.
        class_name: Class name used as a path component.
        instance_name: Instance name; ``".pth"`` is appended.

    Returns:
        The joined path. Nothing is created on disk.
    """
    code_file = instance_name + ".pth"
    path_parts = (
        experiment_dir,
        reconstructions_subdir,
        str(epoch),
        reconstruction_codes_subdir,
        dataset,
        class_name,
        code_file,
    )
    return os.path.join(*path_parts)
def get_evaluation_dir(experiment_dir, checkpoint, create_if_nonexistent=False):
    """Return the evaluation directory of a checkpoint, optionally creating it.

    Args:
        experiment_dir: Root directory of the experiment.
        checkpoint: Checkpoint name used as the last path component.
        create_if_nonexistent: When true, create the directory (including
            missing parents) if it does not exist yet.

    Returns:
        ``<experiment_dir>/<evaluation_subdir>/<checkpoint>``.
    """
    # Named ``eval_dir`` rather than ``dir`` to avoid shadowing the builtin.
    eval_dir = os.path.join(experiment_dir, evaluation_subdir, checkpoint)

    if create_if_nonexistent:
        # ``exist_ok=True`` avoids the check-then-create race when several
        # processes create the directory at the same time.
        os.makedirs(eval_dir, exist_ok=True)

    return eval_dir
def get_model_params_dir(experiment_dir, create_if_nonexistent=False):
    """Return the model-parameters directory, optionally creating it.

    Args:
        experiment_dir: Root directory of the experiment.
        create_if_nonexistent: When true, create the directory (including
            missing parents) if it does not exist yet.

    Returns:
        ``<experiment_dir>/<model_params_subdir>``.
    """
    # Named ``params_dir`` rather than ``dir`` to avoid shadowing the builtin.
    params_dir = os.path.join(experiment_dir, model_params_subdir)

    if create_if_nonexistent:
        # ``exist_ok=True`` avoids the check-then-create race when several
        # processes create the directory at the same time.
        os.makedirs(params_dir, exist_ok=True)

    return params_dir
def get_optimizer_params_dir(experiment_dir, create_if_nonexistent=False):
    """Return the optimizer-parameters directory, optionally creating it.

    Args:
        experiment_dir: Root directory of the experiment.
        create_if_nonexistent: When true, create the directory (including
            missing parents) if it does not exist yet.

    Returns:
        ``<experiment_dir>/<optimizer_params_subdir>``.
    """
    # Named ``params_dir`` rather than ``dir`` to avoid shadowing the builtin.
    params_dir = os.path.join(experiment_dir, optimizer_params_subdir)

    if create_if_nonexistent:
        # ``exist_ok=True`` avoids the check-then-create race when several
        # processes create the directory at the same time.
        os.makedirs(params_dir, exist_ok=True)

    return params_dir
def get_latent_codes_dir(experiment_dir, create_if_nonexistent=False):
    """Return the latent-codes directory, optionally creating it.

    Args:
        experiment_dir: Root directory of the experiment.
        create_if_nonexistent: When true, create the directory (including
            missing parents) if it does not exist yet.

    Returns:
        ``<experiment_dir>/<latent_codes_subdir>``.
    """
    # Named ``codes_dir`` rather than ``dir`` to avoid shadowing the builtin.
    codes_dir = os.path.join(experiment_dir, latent_codes_subdir)

    if create_if_nonexistent:
        # ``exist_ok=True`` avoids the check-then-create race when several
        # processes create the directory at the same time.
        os.makedirs(codes_dir, exist_ok=True)

    return codes_dir
def get_normalization_params_filename(
    data_dir, dataset_name, class_name, instance_name
):
    """Return the path of an instance's normalization parameters file.

    Layout:
    ``<data_dir>/<normalization_param_subdir>/<dataset_name>/``
    ``<class_name>/<instance_name>.npz``

    Args:
        data_dir: Root data directory.
        dataset_name: Dataset name used as a path component.
        class_name: Class name used as a path component.
        instance_name: Instance name; ``".npz"`` is appended.

    Returns:
        The joined path. Nothing is created on disk.
    """
    params_file = instance_name + ".npz"
    path_parts = (
        data_dir,
        normalization_param_subdir,
        dataset_name,
        class_name,
        params_file,
    )
    return os.path.join(*path_parts)