#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.

import glob
import logging
import numpy as np
import os
import random
import torch
import torch.utils.data

import deep_sdf.workspace as ws


def get_instance_filenames(data_source, split):
    npzfiles = []
    for dataset in split:
        for class_name in split[dataset]:
            for instance_name in split[dataset][class_name]:
                instance_filename = os.path.join(
                    dataset, class_name, instance_name + ".npz"
                )
                path = os.path.join(data_source, ws.sdf_samples_subdir, instance_filename)
                # print(path)
                if not os.path.isfile(path):
                    # raise RuntimeError(
                    #     'Requested non-existent file "' + instance_filename + "'"
                    # )
                    logging.warning(
                        "Requested non-existent file '{}'".format(instance_filename)
                    )
                npzfiles += [instance_filename]
    return npzfiles
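

# Illustrative note: `split` is expected to be a nested mapping of
# dataset -> class -> list of instance names, typically loaded from a split JSON
# file. A hypothetical example (names are placeholders, not from this file):
#
#     split = {
#         "ShapeNetV2": {
#             "03001627": ["instance_a", "instance_b"],
#         }
#     }
#
# Each entry then resolves to
# <data_source>/<ws.sdf_samples_subdir>/<dataset>/<class>/<instance>.npz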


class NoMeshFileError(RuntimeError):
    """Raised when a mesh file is not found in a shape directory"""

    pass


class MultipleMeshFileError(RuntimeError):
    """Raised when there are multiple mesh files in a shape directory"""

    pass


def find_mesh_in_directory(shape_dir):
    # Note: without recursive=True, "**" in glob behaves like "*", so this looks
    # for .obj files in shape_dir and in its immediate subdirectories only.
    mesh_filenames = list(glob.iglob(shape_dir + "/**/*.obj")) + list(
        glob.iglob(shape_dir + "/*.obj")
    )
    if len(mesh_filenames) == 0:
        raise NoMeshFileError()
    elif len(mesh_filenames) > 1:
        raise MultipleMeshFileError()
    return mesh_filenames[0]


def remove_nans(tensor):
    # Drop sample rows whose fourth column (the SDF value) is NaN.
    tensor_nan = torch.isnan(tensor[:, 3])
    return tensor[~tensor_nan, :]


def read_sdf_samples_into_ram(filename):
    npz = np.load(filename)
    pos_tensor = torch.from_numpy(npz["pos"])
    neg_tensor = torch.from_numpy(npz["neg"])

    return [pos_tensor, neg_tensor]


def unpack_sdf_samples(filename, subsample=None):
    npz = np.load(filename)
    if subsample is None:
        return npz
    pos_tensor = remove_nans(torch.from_numpy(npz["pos"]))
    neg_tensor = remove_nans(torch.from_numpy(npz["neg"]))

    # split the subsample budget in half between positive and negative samples
    half = int(subsample / 2)

    random_pos = (torch.rand(half) * pos_tensor.shape[0]).long()
    random_neg = (torch.rand(half) * neg_tensor.shape[0]).long()

    sample_pos = torch.index_select(pos_tensor, 0, random_pos)
    sample_neg = torch.index_select(neg_tensor, 0, random_neg)

    samples = torch.cat([sample_pos, sample_neg], 0)

    return samples
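

# Illustrative usage (the path is a placeholder; the .npz is assumed to contain
# "pos" and "neg" arrays whose rows hold a sample point and its SDF value):
#
#     samples = unpack_sdf_samples("path/to/instance.npz", subsample=16384)
#     # samples has `subsample` rows: half drawn (with replacement) from the
#     # positive-SDF samples and half from the negative-SDF samples.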


def unpack_sdf_samples_from_ram(data, subsample=None):
    if subsample is None:
        return data
    pos_tensor = data[0]
    neg_tensor = data[1]

    # split the subsample budget in half between positive and negative samples
    half = int(subsample / 2)

    pos_size = pos_tensor.shape[0]
    neg_size = neg_tensor.shape[0]

    pos_start_ind = random.randint(0, pos_size - half)
    sample_pos = pos_tensor[pos_start_ind : (pos_start_ind + half)]

    if neg_size <= half:
        random_neg = (torch.rand(half) * neg_tensor.shape[0]).long()
        sample_neg = torch.index_select(neg_tensor, 0, random_neg)
    else:
        neg_start_ind = random.randint(0, neg_size - half)
        sample_neg = neg_tensor[neg_start_ind : (neg_start_ind + half)]

    samples = torch.cat([sample_pos, sample_neg], 0)

    return samples
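

# Note: this variant is used when SDFSamples is constructed with load_ram=True.
# Because the cached tensors are pre-shuffled in SDFSamples.__init__, positives
# are taken as a contiguous random slice; negatives fall back to sampling with
# replacement when the tensor has no more than `half` rows.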


class SDFSamples(torch.utils.data.Dataset):
    def __init__(
        self,
        data_source,
        split,
        subsample,
        load_ram=False,
        print_filename=False,
        num_files=1000000,
    ):
        self.subsample = subsample

        self.data_source = data_source
        self.npyfiles = get_instance_filenames(data_source, split)

        logging.debug(
            "using "
            + str(len(self.npyfiles))
            + " shapes from data source "
            + data_source
        )

        self.load_ram = load_ram

        if load_ram:
            # Pre-load every .npz into memory, shuffling the positive and negative
            # sample rows once so that __getitem__ can take contiguous slices.
            self.loaded_data = []
            for f in self.npyfiles:
                filename = os.path.join(self.data_source, ws.sdf_samples_subdir, f)
                npz = np.load(filename)
                pos_tensor = remove_nans(torch.from_numpy(npz["pos"]))
                neg_tensor = remove_nans(torch.from_numpy(npz["neg"]))
                self.loaded_data.append(
                    [
                        pos_tensor[torch.randperm(pos_tensor.shape[0])],
                        neg_tensor[torch.randperm(neg_tensor.shape[0])],
                    ]
                )

    def __len__(self):
        return len(self.npyfiles)

    def __getitem__(self, idx):
        filename = os.path.join(
            self.data_source, ws.sdf_samples_subdir, self.npyfiles[idx]
        )
        if self.load_ram:
            return (
                unpack_sdf_samples_from_ram(self.loaded_data[idx], self.subsample),
                idx,
            )
        else:
            return unpack_sdf_samples(filename, self.subsample), idx
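

# Minimal usage sketch (hypothetical paths and values, not part of this module):
#
#     import json
#     from torch.utils.data import DataLoader
#
#     with open("path/to/split.json") as f:
#         split = json.load(f)
#     dataset = SDFSamples("data", split, subsample=16384, load_ram=False)
#     loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
#     for sdf_samples, indices in loader:
#         # sdf_samples: per-shape subsampled SDF points; indices: positions into
#         # dataset.npyfiles identifying which shape each batch element came from.
#         pass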