You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
591 lines
17 KiB
591 lines
17 KiB
#!/usr/bin/env python3
|
|
# Copyright 2004-present Facebook. All Rights Reserved.
|
|
|
|
import torch
|
|
import torch.utils.data as data_utils
|
|
import signal
|
|
import sys
|
|
import os
|
|
import logging
|
|
import math
|
|
import json
|
|
import time
|
|
|
|
import deep_sdf
|
|
import deep_sdf.workspace as ws
|
|
|
|
|
|
class LearningRateSchedule:
    """Interface for per-epoch learning-rate schedules.

    Subclasses implement ``get_learning_rate(epoch)`` and return the
    learning rate to use for that epoch.
    """

    def get_learning_rate(self, epoch):
        """Return the learning rate for the given epoch.

        Raises:
            NotImplementedError: always, on the base class. The original
                body was ``pass``, which silently returned ``None`` and
                would later surface as a confusing optimizer error; failing
                loudly here pinpoints the missing override instead.
        """
        raise NotImplementedError
|
|
|
|
|
|
class ConstantLearningRateSchedule(LearningRateSchedule):
    """Schedule that yields the same learning rate at every epoch."""

    def __init__(self, value):
        # The single, fixed learning rate.
        self.value = value

    def get_learning_rate(self, epoch):
        """Return the constant rate; `epoch` is ignored."""
        return self.value
|
|
|
|
|
|
class StepLearningRateSchedule(LearningRateSchedule):
    """Geometric step decay: multiply by `factor` every `interval` epochs."""

    def __init__(self, initial, interval, factor):
        # Rate at epoch 0.
        self.initial = initial
        # Number of epochs between decay steps.
        self.interval = interval
        # Multiplicative decay applied at each step (typically < 1).
        self.factor = factor

    def get_learning_rate(self, epoch):
        """Return initial * factor^k, where k = epoch // interval."""
        completed_steps = epoch // self.interval
        return self.initial * self.factor ** completed_steps
|
|
|
|
|
|
class WarmupLearningRateSchedule(LearningRateSchedule):
    """Linear ramp from `initial` to `warmed_up` over `length` epochs."""

    def __init__(self, initial, warmed_up, length):
        # Rate at epoch 0.
        self.initial = initial
        # Rate held after warmup completes.
        self.warmed_up = warmed_up
        # Number of epochs the ramp spans.
        self.length = length

    def get_learning_rate(self, epoch):
        """Return the linearly interpolated rate for `epoch`."""
        if epoch > self.length:
            # Warmup is over: hold the target rate.
            return self.warmed_up
        # At epoch == length the interpolation equals warmed_up exactly.
        progress = epoch / self.length
        return self.initial + (self.warmed_up - self.initial) * progress
|
|
|
|
|
|
def get_learning_rate_schedules(specs):
    """Build the learning-rate schedules declared in the experiment specs.

    Args:
        specs: Experiment specification dict. Its "LearningRateSchedule"
            entry is a list of dicts, each carrying a "Type" key ("Step",
            "Warmup" or "Constant") plus the type-specific parameters.

    Returns:
        A list with one schedule object per entry, in the order given
        (by convention index 0 drives the decoder, index 1 the latent codes).

    Raises:
        Exception: if an entry has an unrecognized "Type".
    """
    schedules = []

    # Fix: the original reused the name `schedule_specs` for both the list
    # and the loop variable, shadowing the list inside the loop body.
    for schedule_spec in specs["LearningRateSchedule"]:
        schedule_type = schedule_spec["Type"]

        if schedule_type == "Step":
            schedules.append(
                StepLearningRateSchedule(
                    schedule_spec["Initial"],
                    schedule_spec["Interval"],
                    schedule_spec["Factor"],
                )
            )
        elif schedule_type == "Warmup":
            schedules.append(
                WarmupLearningRateSchedule(
                    schedule_spec["Initial"],
                    schedule_spec["Final"],
                    schedule_spec["Length"],
                )
            )
        elif schedule_type == "Constant":
            schedules.append(ConstantLearningRateSchedule(schedule_spec["Value"]))
        else:
            raise Exception(
                'no known learning rate schedule of type "{}"'.format(schedule_type)
            )

    return schedules
|
|
|
|
|
|
def save_model(experiment_directory, filename, decoder, epoch):
    """Checkpoint the decoder weights (tagged with `epoch`) to `filename`.

    The file is written into the experiment's model-parameters directory,
    which is created on demand (the `True` flag).
    """
    params_dir = ws.get_model_params_dir(experiment_directory, True)
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": decoder.state_dict(),
    }
    torch.save(checkpoint, os.path.join(params_dir, filename))
|
|
|
|
|
|
def save_optimizer(experiment_directory, filename, optimizer, epoch):
    """Checkpoint the optimizer state (tagged with `epoch`) to `filename`.

    The file is written into the experiment's optimizer-parameters
    directory, created on demand (the `True` flag).
    """
    params_dir = ws.get_optimizer_params_dir(experiment_directory, True)
    checkpoint = {
        "epoch": epoch,
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(checkpoint, os.path.join(params_dir, filename))
|
|
|
|
|
|
def load_optimizer(experiment_directory, filename, optimizer):
    """Restore `optimizer` from a saved checkpoint and return its epoch.

    Raises:
        Exception: if the checkpoint file does not exist.
    """
    full_filename = os.path.join(
        ws.get_optimizer_params_dir(experiment_directory), filename
    )

    if not os.path.isfile(full_filename):
        raise Exception(
            'optimizer state dict "{}" does not exist'.format(full_filename)
        )

    checkpoint = torch.load(full_filename)
    # Mutates the optimizer in place; the saved epoch is the return value.
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]
|
|
|
|
|
|
def save_latent_vectors(experiment_directory, filename, latent_vec, epoch):
    """Checkpoint the latent-code embedding (tagged with `epoch`).

    `latent_vec` is an nn.Embedding-like module; its full state dict is
    stored under the "latent_codes" key so `load_latent_vectors` can
    restore it.
    """
    codes_dir = ws.get_latent_codes_dir(experiment_directory, True)
    checkpoint = {
        "epoch": epoch,
        "latent_codes": latent_vec.state_dict(),
    }
    torch.save(checkpoint, os.path.join(codes_dir, filename))
|
|
|
|
|
|
# TODO: duplicated in workspace
def load_latent_vectors(experiment_directory, filename, lat_vecs):
    """Load saved latent codes into the `lat_vecs` embedding, in place.

    Supports two on-disk formats: the current one, where "latent_codes" is
    an nn.Embedding state dict, and a legacy one where it is a raw tensor.

    Returns:
        The epoch recorded in the checkpoint.

    Raises:
        Exception: if the file is missing, or (legacy format) if the saved
            codes disagree with `lat_vecs` in count or dimensionality.
    """
    full_filename = os.path.join(
        ws.get_latent_codes_dir(experiment_directory), filename
    )

    if not os.path.isfile(full_filename):
        raise Exception('latent state file "{}" does not exist'.format(full_filename))

    data = torch.load(full_filename)

    if isinstance(data["latent_codes"], torch.Tensor):

        # for backwards compatibility
        # Legacy format: a raw tensor of codes rather than a state dict.
        if not lat_vecs.num_embeddings == data["latent_codes"].size()[0]:
            raise Exception(
                "num latent codes mismatched: {} vs {}".format(
                    lat_vecs.num_embeddings, data["latent_codes"].size()[0]
                )
            )

        # NOTE(review): indexing size()[2] implies the legacy tensor is
        # (num_codes, 1, code_dim) — confirm against old checkpoints.
        if not lat_vecs.embedding_dim == data["latent_codes"].size()[2]:
            raise Exception("latent code dimensionality mismatch")

        # Copy each saved code into the embedding's weight rows (the (1, D)
        # rows broadcast into the (D,) destination slice).
        for i, lat_vec in enumerate(data["latent_codes"]):
            lat_vecs.weight.data[i, :] = lat_vec

    else:
        # Current format: restore the embedding's state dict directly.
        lat_vecs.load_state_dict(data["latent_codes"])

    return data["epoch"]
|
|
|
|
|
|
def save_logs(
    experiment_directory,
    loss_log,
    lr_log,
    timing_log,
    lat_mag_log,
    param_mag_log,
    epoch,
):
    """Persist all training histories to the experiment's single log file.

    `loss_log` is per-iteration; the other logs are per-epoch. `epoch`
    records how far training had progressed when the logs were written.
    """
    logs = {
        "epoch": epoch,
        "loss": loss_log,
        "learning_rate": lr_log,
        "timing": timing_log,
        "latent_magnitude": lat_mag_log,
        "param_magnitude": param_mag_log,
    }
    torch.save(logs, os.path.join(experiment_directory, ws.logs_filename))
|
|
|
|
|
|
def load_logs(experiment_directory):
    """Load the training histories written by `save_logs`.

    Returns:
        Tuple of (loss_log, lr_log, timing_log, lat_mag_log,
        param_mag_log, epoch), in that order.

    Raises:
        Exception: if the log file does not exist.
    """
    full_filename = os.path.join(experiment_directory, ws.logs_filename)

    if not os.path.isfile(full_filename):
        raise Exception('log file "{}" does not exist'.format(full_filename))

    data = torch.load(full_filename)

    # Keep the return order stable — callers unpack positionally.
    keys = (
        "loss",
        "learning_rate",
        "timing",
        "latent_magnitude",
        "param_magnitude",
        "epoch",
    )
    return tuple(data[key] for key in keys)
|
|
|
|
|
|
def clip_logs(loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, epoch):
    """Truncate all training histories so they end exactly at `epoch`.

    Used when resuming: the on-disk logs may run past the checkpointed
    epoch. The loss log has one entry per iteration while the other logs
    have one entry per epoch, so the iteration count per epoch is inferred
    from the current list lengths (assumes it was constant throughout).

    Returns:
        Tuple of the clipped (loss_log, lr_log, timing_log, lat_mag_log,
        param_mag_log). `param_mag_log` is truncated in place.
    """
    # Per-iteration entries divided by per-epoch entries.
    iters_per_epoch = len(loss_log) // len(lr_log)

    clipped_loss = loss_log[: iters_per_epoch * epoch]
    clipped_lr = lr_log[:epoch]
    clipped_timing = timing_log[:epoch]
    clipped_lat_mag = lat_mag_log[:epoch]
    for name in param_mag_log:
        param_mag_log[name] = param_mag_log[name][:epoch]

    return (clipped_loss, clipped_lr, clipped_timing, clipped_lat_mag, param_mag_log)
|
|
|
|
|
|
def get_spec_with_default(specs, key, default):
    """Return `specs[key]`, or `default` when the key is absent."""
    if key in specs:
        return specs[key]
    return default
|
|
|
|
|
|
def get_mean_latent_vector_magnitude(latent_vectors):
    """Return the mean L2 norm of the rows of the latent-code embedding."""
    codes = latent_vectors.weight.data.detach()
    per_code_norms = torch.norm(codes, dim=1)
    return torch.mean(per_code_norms)
|
|
|
|
|
|
def append_parameter_magnitudes(param_mag_log, model):
    """Append each model parameter's current L2 norm to its history.

    `param_mag_log` maps parameter name -> list of norms (one entry per
    call); missing names get a fresh list. The "module." prefix that
    nn.DataParallel prepends is stripped so names stay stable.
    """
    prefix = "module."
    for name, param in model.named_parameters():
        if len(name) > len(prefix) and name.startswith(prefix):
            name = name[len(prefix):]
        history = param_mag_log.setdefault(name, [])
        history.append(param.data.norm().item())
|
|
|
|
|
|
def main_function(experiment_directory, continue_from, batch_split):
    """Train a DeepSDF auto-decoder as configured by the experiment specs.

    Args:
        experiment_directory: Directory holding the experiment's specs
            (read via `ws.load_experiment_specifications`); checkpoints and
            logs are written back into it.
        continue_from: Checkpoint tag to resume from (e.g. "latest" or an
            epoch number as a string), or None to start fresh.
        batch_split: Number of sub-batches each batch is split into;
            gradients are accumulated across sub-batches before a single
            optimizer step (trades memory for speed).
    """

    logging.debug("running " + experiment_directory)

    specs = ws.load_experiment_specifications(experiment_directory)

    logging.info("Experiment description: \n" + specs["Description"][0])

    data_source = specs["DataSource"]
    train_split_file = specs["TrainSplit"]

    # Dynamically import the decoder architecture module named in the specs.
    arch = __import__("networks." + specs["NetworkArch"], fromlist=["Decoder"])

    logging.debug(specs["NetworkSpecs"])

    latent_size = specs["CodeLength"]

    # Epochs at which full (named) snapshots are taken: every
    # SnapshotFrequency epochs, plus any explicitly listed extras.
    checkpoints = list(
        range(
            specs["SnapshotFrequency"],
            specs["NumEpochs"] + 1,
            specs["SnapshotFrequency"],
        )
    )

    for checkpoint in specs["AdditionalSnapshots"]:
        checkpoints.append(checkpoint)
    checkpoints.sort()

    lr_schedules = get_learning_rate_schedules(specs)

    # Optional gradient-norm clipping threshold (None disables it).
    grad_clip = get_spec_with_default(specs, "GradientClipNorm", None)
    if grad_clip is not None:
        logging.debug("clipping gradients to max norm {}".format(grad_clip))

    def save_latest(epoch):
        # Overwrite the rolling "latest" snapshot (model + optimizer + codes).
        save_model(experiment_directory, "latest.pth", decoder, epoch)
        save_optimizer(experiment_directory, "latest.pth", optimizer_all, epoch)
        save_latent_vectors(experiment_directory, "latest.pth", lat_vecs, epoch)

    def save_checkpoints(epoch):
        # Write an epoch-named snapshot (model + optimizer + codes).
        save_model(experiment_directory, str(epoch) + ".pth", decoder, epoch)
        save_optimizer(experiment_directory, str(epoch) + ".pth", optimizer_all, epoch)
        save_latent_vectors(experiment_directory, str(epoch) + ".pth", lat_vecs, epoch)

    def signal_handler(sig, frame):
        # SIGINT (Ctrl-C) exits cleanly instead of leaving a stack trace.
        logging.info("Stopping early...")
        sys.exit(0)

    def adjust_learning_rate(lr_schedules, optimizer, epoch):
        # Param group i follows schedule i (0: decoder, 1: latent codes).
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

    def empirical_stat(latent_vecs, indices):
        # NOTE(review): defined but not called anywhere in this function —
        # appears to be a leftover diagnostic helper.
        lat_mat = torch.zeros(0).cuda()
        for ind in indices:
            lat_mat = torch.cat([lat_mat, latent_vecs[ind]], 0)
        mean = torch.mean(lat_mat, 0)
        var = torch.var(lat_mat, 0)
        return mean, var

    signal.signal(signal.SIGINT, signal_handler)

    num_samp_per_scene = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    # SDF values (both ground truth and predictions) are clamped to
    # [-clamp_dist, clamp_dist], as in the DeepSDF paper.
    clamp_dist = specs["ClampingDistance"]
    minT = -clamp_dist
    maxT = clamp_dist
    enforce_minmax = True

    # L2 regularization on latent-code magnitudes (ramped in over 100 epochs).
    do_code_regularization = get_spec_with_default(specs, "CodeRegularization", True)
    code_reg_lambda = get_spec_with_default(specs, "CodeRegularizationLambda", 1e-4)

    # Optional max-norm constraint on the latent codes (None = unconstrained).
    code_bound = get_spec_with_default(specs, "CodeBound", None)

    decoder = arch.Decoder(latent_size, **specs["NetworkSpecs"]).cuda()

    logging.info("training with {} GPU(s)".format(torch.cuda.device_count()))

    # if torch.cuda.device_count() > 1:
    # Always wrapped in DataParallel (even for one GPU), so parameter names
    # gain a "module." prefix — see append_parameter_magnitudes.
    decoder = torch.nn.DataParallel(decoder)

    num_epochs = specs["NumEpochs"]
    log_frequency = get_spec_with_default(specs, "LogFrequency", 10)

    with open(train_split_file, "r") as f:
        train_split = json.load(f)

    sdf_dataset = deep_sdf.data.SDFSamples(
        data_source, train_split, num_samp_per_scene, load_ram=False
    )

    num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 1)
    logging.debug("loading data with {} threads".format(num_data_loader_threads))

    # drop_last keeps every batch at exactly scene_per_batch scenes, which
    # the chunking logic below relies on.
    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=scene_per_batch,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )

    logging.debug("torch num_threads: {}".format(torch.get_num_threads()))

    num_scenes = len(sdf_dataset)

    logging.info("There are {} scenes".format(num_scenes))

    logging.debug(decoder)

    # One learnable latent code per training scene (the "auto-decoder").
    lat_vecs = torch.nn.Embedding(num_scenes, latent_size, max_norm=code_bound)
    torch.nn.init.normal_(
        lat_vecs.weight.data,
        0.0,
        # Scale the init std so the expected code magnitude is
        # independent of the code dimensionality.
        get_spec_with_default(specs, "CodeInitStdDev", 1.0) / math.sqrt(latent_size),
    )

    logging.debug(
        "initialized with mean magnitude {}".format(
            get_mean_latent_vector_magnitude(lat_vecs)
        )
    )

    # Summed L1 loss; divided by the total sample count below so the loss
    # is a mean over all SDF samples in the batch.
    loss_l1 = torch.nn.L1Loss(reduction="sum")

    # Two parameter groups so the decoder and the latent codes can follow
    # independent learning-rate schedules.
    optimizer_all = torch.optim.Adam(
        [
            {
                "params": decoder.parameters(),
                "lr": lr_schedules[0].get_learning_rate(0),
            },
            {
                "params": lat_vecs.parameters(),
                "lr": lr_schedules[1].get_learning_rate(0),
            },
        ]
    )

    # Histories: loss is per-iteration; the rest are per-epoch.
    loss_log = []
    lr_log = []
    lat_mag_log = []
    timing_log = []
    param_mag_log = {}

    start_epoch = 1

    if continue_from is not None:

        logging.info('continuing from "{}"'.format(continue_from))

        lat_epoch = load_latent_vectors(
            experiment_directory, continue_from + ".pth", lat_vecs
        )

        model_epoch = ws.load_model_parameters(
            experiment_directory, continue_from, decoder
        )

        optimizer_epoch = load_optimizer(
            experiment_directory, continue_from + ".pth", optimizer_all
        )

        loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, log_epoch = load_logs(
            experiment_directory
        )

        # Logs may run past the checkpointed epoch (written more often);
        # clip them back so histories line up with the model state.
        if not log_epoch == model_epoch:
            loss_log, lr_log, timing_log, lat_mag_log, param_mag_log = clip_logs(
                loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, model_epoch
            )

        # The three checkpoint files must agree on the epoch they represent.
        if not (model_epoch == optimizer_epoch and model_epoch == lat_epoch):
            raise RuntimeError(
                "epoch mismatch: {} vs {} vs {} vs {}".format(
                    model_epoch, optimizer_epoch, lat_epoch, log_epoch
                )
            )

        start_epoch = model_epoch + 1

        logging.debug("loaded")

    logging.info("starting from epoch {}".format(start_epoch))

    logging.info(
        "Number of decoder parameters: {}".format(
            sum(p.data.nelement() for p in decoder.parameters())
        )
    )
    logging.info(
        "Number of shape code parameters: {} (# codes {}, code dim {})".format(
            lat_vecs.num_embeddings * lat_vecs.embedding_dim,
            lat_vecs.num_embeddings,
            lat_vecs.embedding_dim,
        )
    )

    # ---- main training loop (epochs are 1-indexed) ----
    for epoch in range(start_epoch, num_epochs + 1):

        start = time.time()

        logging.info("epoch {}...".format(epoch))

        decoder.train()

        adjust_learning_rate(lr_schedules, optimizer_all, epoch)

        for sdf_data, indices in sdf_loader:

            # Process the input data
            # Flatten (scenes, samples, 4) into rows of (x, y, z, sdf).
            sdf_data = sdf_data.reshape(-1, 4)

            num_sdf_samples = sdf_data.shape[0]

            sdf_data.requires_grad = False

            xyz = sdf_data[:, 0:3]
            sdf_gt = sdf_data[:, 3].unsqueeze(1)

            if enforce_minmax:
                sdf_gt = torch.clamp(sdf_gt, minT, maxT)

            # Split the batch into `batch_split` sub-batches; gradients
            # accumulate across them before one optimizer step.
            xyz = torch.chunk(xyz, batch_split)
            # Expand per-scene indices so each sample row knows its scene's
            # latent code, then chunk to match xyz.
            indices = torch.chunk(
                indices.unsqueeze(-1).repeat(1, num_samp_per_scene).view(-1),
                batch_split,
            )

            sdf_gt = torch.chunk(sdf_gt, batch_split)

            batch_loss = 0.0

            optimizer_all.zero_grad()

            for i in range(batch_split):

                # Look up the latent code for each sample's scene.
                batch_vecs = lat_vecs(indices[i])

                # Decoder input is [latent code | xyz] per sample.
                # NOTE(review): `input` shadows the builtin; kept as-is.
                input = torch.cat([batch_vecs, xyz[i]], dim=1)

                # NN optimization
                pred_sdf = decoder(input)

                if enforce_minmax:
                    pred_sdf = torch.clamp(pred_sdf, minT, maxT)

                # Normalize by the FULL batch's sample count so accumulated
                # sub-batch gradients sum to the whole-batch mean.
                chunk_loss = loss_l1(pred_sdf, sdf_gt[i].cuda()) / num_sdf_samples

                if do_code_regularization:
                    # Penalize code magnitudes; min(1, epoch/100) ramps the
                    # penalty in over the first 100 epochs.
                    l2_size_loss = torch.sum(torch.norm(batch_vecs, dim=1))
                    reg_loss = (
                        code_reg_lambda * min(1, epoch / 100) * l2_size_loss
                    ) / num_sdf_samples

                    chunk_loss = chunk_loss + reg_loss.cuda()

                # Backward per sub-batch; gradients accumulate until step().
                chunk_loss.backward()

                batch_loss += chunk_loss.item()

            logging.debug("loss = {}".format(batch_loss))

            loss_log.append(batch_loss)

            if grad_clip is not None:

                torch.nn.utils.clip_grad_norm_(decoder.parameters(), grad_clip)

            optimizer_all.step()

        end = time.time()

        seconds_elapsed = end - start
        timing_log.append(seconds_elapsed)

        lr_log.append([schedule.get_learning_rate(epoch) for schedule in lr_schedules])

        lat_mag_log.append(get_mean_latent_vector_magnitude(lat_vecs))

        append_parameter_magnitudes(param_mag_log, decoder)

        if epoch in checkpoints:
            save_checkpoints(epoch)

        if epoch % log_frequency == 0:

            save_latest(epoch)
            save_logs(
                experiment_directory,
                loss_log,
                lr_log,
                timing_log,
                lat_mag_log,
                param_mag_log,
                epoch,
            )
|
|
|
|
|
|
if __name__ == "__main__":

    import argparse

    # Command-line entry point: parse arguments, set up logging, train.
    arg_parser = argparse.ArgumentParser(description="Train a DeepSDF autodecoder")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        required=True,
        help="The experiment directory. This directory should include "
        + "experiment specifications in 'specs.json', and logging will be "
        + "done in this directory as well.",
    )
    arg_parser.add_argument(
        "--continue",
        "-c",
        dest="continue_from",
        help="A snapshot to continue from. This can be 'latest' to continue"
        + "from the latest running snapshot, or an integer corresponding to "
        + "an epochal snapshot.",
    )
    arg_parser.add_argument(
        "--batch_split",
        dest="batch_split",
        default=1,
        help="This splits the batch into separate subbatches which are "
        + "processed separately, with gradients accumulated across all "
        + "subbatches. This allows for training with large effective batch "
        + "sizes in memory constrained environments.",
    )

    # Shared flags (e.g. logging verbosity) defined by the deep_sdf package.
    deep_sdf.add_common_args(arg_parser)

    args = arg_parser.parse_args()

    deep_sdf.configure_logging(args)

    # batch_split arrives as a string (no type=int on the argument),
    # hence the explicit int() conversion here.
    main_function(args.experiment_directory, args.continue_from, int(args.batch_split))
|
|
|