|
|
@ -83,35 +83,18 @@ class Net(nn.Module): |
|
|
|
|
|
|
|
#self.csg_combiner = CSGCombiner(flag_convex=True) |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def forward(self, query_points): |
|
|
|
""" |
|
|
|
前向传播 |
|
|
|
|
|
|
|
参数: |
|
|
|
query_point: 查询点的位置坐标 |
|
|
|
返回: |
|
|
|
output: 解码后的输出结果 |
|
|
|
""" |
|
|
|
# 批量查询所有点的索引和bbox |
|
|
|
_,face_indices_mask,operator = self.octree_module.forward(query_points) |
|
|
|
# 编码 |
|
|
|
feature_vectors = self.encoder.forward(query_points,face_indices_mask) |
|
|
|
print("feature_vector:", feature_vectors.shape) |
|
|
|
# 解码 |
|
|
|
#logger.gpu_memory_stats("encoder farward后") |
|
|
|
f_i = self.decoder(feature_vectors) # (B, P) |
|
|
|
#logger.gpu_memory_stats("decoder farward后") |
|
|
|
|
|
|
|
|
|
|
|
def process_sdf(self,f_i, face_indices_mask, operator): |
|
|
|
output = f_i[:,0] |
|
|
|
# 提取有效值并填充到固定大小 (B, max_patches) |
|
|
|
padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches) |
|
|
|
for i in range(f_i.shape[0]): |
|
|
|
sample_valid_values = f_i[i][face_indices_mask[i]] # (N,), N <= P |
|
|
|
num_valid = min(len(sample_valid_values), 2) |
|
|
|
padded_f_i[i, :num_valid] = sample_valid_values[:num_valid] |
|
|
|
valid_mask = face_indices_mask.bool() # 确保是布尔类型 (B, P) |
|
|
|
masked_f_i = torch.where(valid_mask, f_i, torch.tensor(float('inf'), device=f_i.device)) # 将无效值设置为 inf |
|
|
|
|
|
|
|
# 对每个样本取前 max_patches 个有效值 (B, max_patches) |
|
|
|
valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False) # 提取前两个有效值 |
|
|
|
|
|
|
|
# 填充到固定大小 (B, max_patches) |
|
|
|
padded_f_i[:, :2] = valid_values # (B, max_patches) |
|
|
|
|
|
|
|
# 找到需要组合的行 |
|
|
|
mask_concave = (operator == 0) |
|
|
@ -125,10 +108,36 @@ class Net(nn.Module): |
|
|
|
if mask_convex.any(): |
|
|
|
output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values |
|
|
|
|
|
|
|
|
|
|
|
logger.debug("step over") |
|
|
|
#logger.gpu_memory_stats("combine后") |
|
|
|
return output |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def forward(self, query_points): |
|
|
|
""" |
|
|
|
前向传播 |
|
|
|
|
|
|
|
参数: |
|
|
|
query_point: 查询点的位置坐标 |
|
|
|
返回: |
|
|
|
output: 解码后的输出结果 |
|
|
|
""" |
|
|
|
# 批量查询所有点的索引和bbox |
|
|
|
#logger.debug("step octree") |
|
|
|
_,face_indices_mask,operator = self.octree_module.forward(query_points) |
|
|
|
#logger.debug("step encode") |
|
|
|
# 编码 |
|
|
|
feature_vectors = self.encoder.forward(query_points,face_indices_mask) |
|
|
|
#print("feature_vector:", feature_vectors.shape) |
|
|
|
# 解码 |
|
|
|
#logger.debug("step decode") |
|
|
|
#logger.gpu_memory_stats("encoder farward后") |
|
|
|
f_i = self.decoder(feature_vectors) # (B, P) |
|
|
|
#logger.gpu_memory_stats("decoder farward后") |
|
|
|
|
|
|
|
#logger.debug("step combine") |
|
|
|
return self.process_sdf(f_i, face_indices_mask, operator) |
|
|
|
|
|
|
|
@torch.jit.export |
|
|
|
def forward_training_volumes(self, surf_points, patch_id:int): |
|
|
|
""" |
|
|
|