Browse Source

优化net计算

final
mckay 1 month ago
parent
commit
061ddfeda0
  1. 61
      brep2sdf/networks/network.py

61
brep2sdf/networks/network.py

@ -83,6 +83,35 @@ class Net(nn.Module):
#self.csg_combiner = CSGCombiner(flag_convex=True)
def process_sdf(self,f_i, face_indices_mask, operator):
output = f_i[:,0]
# 提取有效值并填充到固定大小 (B, max_patches)
padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches)
valid_mask = face_indices_mask.bool() # 确保是布尔类型 (B, P)
masked_f_i = torch.where(valid_mask, f_i, torch.tensor(float('inf'), device=f_i.device)) # 将无效值设置为 inf
# 对每个样本取前 max_patches 个有效值 (B, max_patches)
valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False) # 提取前两个有效值
# 填充到固定大小 (B, max_patches)
padded_f_i[:, :2] = valid_values # (B, max_patches)
# 找到需要组合的行
mask_concave = (operator == 0)
mask_convex = (operator == 1)
# 对 operator == 0 的样本取最大值
if mask_concave.any():
output[mask_concave] = torch.max(padded_f_i[mask_concave], dim=1).values
# 对 operator == 1 的样本取最小值
if mask_convex.any():
output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values
logger.debug("step over")
#logger.gpu_memory_stats("combine后")
return output
@torch.jit.export
def forward(self, query_points):
"""
@ -94,40 +123,20 @@ class Net(nn.Module):
output: 解码后的输出结果
"""
# 批量查询所有点的索引和bbox
#logger.debug("step octree")
_,face_indices_mask,operator = self.octree_module.forward(query_points)
#logger.debug("step encode")
# 编码
feature_vectors = self.encoder.forward(query_points,face_indices_mask)
print("feature_vector:", feature_vectors.shape)
#print("feature_vector:", feature_vectors.shape)
# 解码
#logger.debug("step decode")
#logger.gpu_memory_stats("encoder farward后")
f_i = self.decoder(feature_vectors) # (B, P)
#logger.gpu_memory_stats("decoder farward后")
output = f_i[:, 0]
# 提取有效值并填充到固定大小 (B, max_patches)
padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches)
for i in range(f_i.shape[0]):
sample_valid_values = f_i[i][face_indices_mask[i]] # (N,), N <= P
num_valid = min(len(sample_valid_values), 2)
padded_f_i[i, :num_valid] = sample_valid_values[:num_valid]
# 找到需要组合的行
mask_concave = (operator == 0)
mask_convex = (operator == 1)
# 对 operator == 0 的样本取最大值
if mask_concave.any():
output[mask_concave] = torch.max(padded_f_i[mask_concave], dim=1).values
# 对 operator == 1 的样本取最小值
if mask_convex.any():
output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values
#logger.gpu_memory_stats("combine后")
return output
#logger.debug("step combine")
return self.process_sdf(f_i, face_indices_mask, operator)
@torch.jit.export
def forward_training_volumes(self, surf_points, patch_id:int):

Loading…
Cancel
Save