
Optimize net computation

final
mckay committed 1 month ago
commit 061ddfeda0
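The optimization replaces the per-sample Python loop that padded each point's valid patch SDF values with a single batched `torch.where` + `torch.topk` selection, as shown in the diff below. A minimal sketch of that pattern, with toy shapes that are assumptions for illustration rather than values from the repository:

    import torch

    B, P, K = 4, 6, 2                             # toy sizes: B query points, P patches, keep K values
    f_i = torch.randn(B, P)                       # per-patch SDF predictions
    face_indices_mask = torch.rand(B, P) > 0.5    # which patches are valid for each point

    # Before: loop over samples, copying the first K valid values into a padded buffer
    padded_loop = torch.full((B, K), float('inf'))
    for i in range(B):
        vals = f_i[i][face_indices_mask[i]]
        n = min(len(vals), K)
        padded_loop[i, :n] = vals[:n]

    # After: mask invalid entries with +inf, then take the K smallest values per row in one call
    masked = torch.where(face_indices_mask, f_i, torch.tensor(float('inf')))
    padded_vec, _ = torch.topk(masked, k=K, dim=1, largest=False)

The two are not strictly identical: the loop keeps the first K valid values in index order, while topk with largest=False keeps the K smallest; both results then go through the min/max combine in process_sdf.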
brep2sdf/networks/network.py  63
@@ -83,35 +83,18 @@ class Net(nn.Module):
        #self.csg_combiner = CSGCombiner(flag_convex=True)
    @torch.jit.export
    def forward(self, query_points):
        """
        Forward pass
        Args:
            query_points: coordinates of the query points
        Returns:
            output: the decoded result
        """
        # Batch-query face indices and bounding boxes for all points
        _, face_indices_mask, operator = self.octree_module.forward(query_points)
        # Encode
        feature_vectors = self.encoder.forward(query_points, face_indices_mask)
        print("feature_vector:", feature_vectors.shape)
        # Decode
        #logger.gpu_memory_stats("after encoder forward")
        f_i = self.decoder(feature_vectors)  # (B, P)
        #logger.gpu_memory_stats("after decoder forward")
        output = f_i[:, 0]
    def process_sdf(self, f_i, face_indices_mask, operator):
        output = f_i[:, 0]
        # Extract the valid values and pad them to a fixed size (B, max_patches)
        padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device)  # (B, max_patches)
        for i in range(f_i.shape[0]):
            sample_valid_values = f_i[i][face_indices_mask[i]]  # (N,), N <= P
            num_valid = min(len(sample_valid_values), 2)
            padded_f_i[i, :num_valid] = sample_valid_values[:num_valid]
        valid_mask = face_indices_mask.bool()  # ensure boolean dtype (B, P)
        masked_f_i = torch.where(valid_mask, f_i, torch.tensor(float('inf'), device=f_i.device))  # set invalid entries to inf
        # Take the max_patches smallest valid values per sample (B, max_patches)
        valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False)  # the two smallest valid values
        # Pad to the fixed size (B, max_patches)
        padded_f_i[:, :2] = valid_values  # (B, max_patches)
        # Find the rows that need to be combined
        mask_concave = (operator == 0)
@@ -125,10 +108,36 @@ class Net(nn.Module):
        if mask_convex.any():
            output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values
        logger.debug("step over")
        #logger.gpu_memory_stats("after combine")
        return output
    @torch.jit.export
    def forward(self, query_points):
        """
        Forward pass
        Args:
            query_points: coordinates of the query points
        Returns:
            output: the decoded result
        """
        # Batch-query face indices and bounding boxes for all points
        #logger.debug("step octree")
        _, face_indices_mask, operator = self.octree_module.forward(query_points)
        #logger.debug("step encode")
        # Encode
        feature_vectors = self.encoder.forward(query_points, face_indices_mask)
        #print("feature_vector:", feature_vectors.shape)
        # Decode
        #logger.debug("step decode")
        #logger.gpu_memory_stats("after encoder forward")
        f_i = self.decoder(feature_vectors)  # (B, P)
        #logger.gpu_memory_stats("after decoder forward")
        #logger.debug("step combine")
        return self.process_sdf(f_i, face_indices_mask, operator)
    @torch.jit.export
    def forward_training_volumes(self, surf_points, patch_id: int):
        """

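For completeness, a small sketch of the operator-based combine that process_sdf applies to the padded values; the `operator == 1` test for the convex case is an assumption, since its defining line falls outside the hunks shown above:

    import torch

    f_i = torch.randn(4, 6)                            # toy per-patch predictions
    padded_f_i = torch.sort(f_i, dim=1).values[:, :2]  # stand-in for the masked top-k result (B, 2)
    operator = torch.tensor([0, 1, 0, 1])              # 0 = concave (from the diff); 1 = convex is assumed

    output = f_i[:, 0].clone()
    mask_concave = (operator == 0)
    mask_convex = (operator == 1)                      # assumed encoding; not visible in the shown hunks
    if mask_concave.any():
        output[mask_concave] = torch.max(padded_f_i[mask_concave], dim=1).values
    if mask_convex.any():
        output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values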