From 061ddfeda025e34e0b849eaf5a8ec2b55ce66c53 Mon Sep 17 00:00:00 2001 From: mckay Date: Mon, 28 Apr 2025 20:25:41 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96net=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/network.py | 63 ++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index f8e7736..dcfb81a 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -83,35 +83,18 @@ class Net(nn.Module): #self.csg_combiner = CSGCombiner(flag_convex=True) - @torch.jit.export - def forward(self, query_points): - """ - 前向传播 - - 参数: - query_point: 查询点的位置坐标 - 返回: - output: 解码后的输出结果 - """ - # 批量查询所有点的索引和bbox - _,face_indices_mask,operator = self.octree_module.forward(query_points) - # 编码 - feature_vectors = self.encoder.forward(query_points,face_indices_mask) - print("feature_vector:", feature_vectors.shape) - # 解码 - #logger.gpu_memory_stats("encoder farward后") - f_i = self.decoder(feature_vectors) # (B, P) - #logger.gpu_memory_stats("decoder farward后") - - - output = f_i[:, 0] + def process_sdf(self,f_i, face_indices_mask, operator): + output = f_i[:,0] # 提取有效值并填充到固定大小 (B, max_patches) padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches) - for i in range(f_i.shape[0]): - sample_valid_values = f_i[i][face_indices_mask[i]] # (N,), N <= P - num_valid = min(len(sample_valid_values), 2) - padded_f_i[i, :num_valid] = sample_valid_values[:num_valid] + valid_mask = face_indices_mask.bool() # 确保是布尔类型 (B, P) + masked_f_i = torch.where(valid_mask, f_i, torch.tensor(float('inf'), device=f_i.device)) # 将无效值设置为 inf + + # 对每个样本取前 max_patches 个有效值 (B, max_patches) + valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False) # 提取前两个有效值 + # 填充到固定大小 (B, max_patches) + padded_f_i[:, :2] = valid_values # (B, max_patches) # 找到需要组合的行 mask_concave = (operator == 0) @@ -125,10 +108,36 @@ class Net(nn.Module): if mask_convex.any(): output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values - + logger.debug("step over") #logger.gpu_memory_stats("combine后") return output + @torch.jit.export + def forward(self, query_points): + """ + 前向传播 + + 参数: + query_point: 查询点的位置坐标 + 返回: + output: 解码后的输出结果 + """ + # 批量查询所有点的索引和bbox + #logger.debug("step octree") + _,face_indices_mask,operator = self.octree_module.forward(query_points) + #logger.debug("step encode") + # 编码 + feature_vectors = self.encoder.forward(query_points,face_indices_mask) + #print("feature_vector:", feature_vectors.shape) + # 解码 + #logger.debug("step decode") + #logger.gpu_memory_stats("encoder farward后") + f_i = self.decoder(feature_vectors) # (B, P) + #logger.gpu_memory_stats("decoder farward后") + + #logger.debug("step combine") + return self.process_sdf(f_i, face_indices_mask, operator) + @torch.jit.export def forward_training_volumes(self, surf_points, patch_id:int): """