diff --git a/brep2sdf/IsoSurfacing.py b/brep2sdf/IsoSurfacing.py index 2e5d5b3..c10006b 100644 --- a/brep2sdf/IsoSurfacing.py +++ b/brep2sdf/IsoSurfacing.py @@ -35,23 +35,28 @@ def predict_sdf(model, points, device, use_bk=False): """ points_t = torch.from_numpy(points).float().to(device) logger.print_tensor_stats("input poitns", points_t) + #logger.info(f"points_t:{points_t.shape}") with torch.no_grad(): if use_bk: print("only background") sdf = model.forward_background(points_t) else: - batch_size = 8192*4 # 定义批量大小 + batch_size = 8192*128 # 定义批量大小 sdf_list = [] # 用于存储批量预测结果 + model.octree_module = model.octree_module.to(points_t.device) for i in range(0, len(points), batch_size): batch_points = points[i:i + batch_size] points_t = torch.from_numpy(batch_points).float().to(device) - logger.print_tensor_stats("input points", points_t) + logger.print_tensor_stats("points_t", points_t) batch_sdf = model(points_t) + logger.print_tensor_stats("batch_sdf", batch_sdf) sdf_list.append(batch_sdf.cpu()) sdf = torch.cat(sdf_list) # 合并所有批量结果 + #logger.info(f"sdf:{sdf.shape}") logger.print_tensor_stats("sdf", sdf) sdf = sdf.cpu().numpy().flatten() + #logger.info(f"sdf:{sdf.shape}") return sdf def extract_surface(sdf, xx, yy, zz, method='MC', bbox_size=1.0,feature_angle=30.0, voxel_size=0.01): diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index 392e259..cb58419 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -49,9 +49,9 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs1: int = 10000 - num_epochs2: int = 0000 - num_epochs3: int = 0000 + num_epochs1: int = 0000 + num_epochs2: int = 000 + num_epochs3: int = 100 learning_rate: float = 0.1 learning_rate_schedule: List = field(default_factory=lambda: [{ "Type": "Step", # 学习率调度类型。"Step"表示在指定迭代次数后将学习率乘以因子 diff --git a/brep2sdf/data/sampler.py b/brep2sdf/data/sampler.py index a8fb12d..c3f7c05 100644 --- a/brep2sdf/data/sampler.py +++ b/brep2sdf/data/sampler.py @@ -438,7 +438,7 @@ def sample_grid(trimesh_mesh_ncs: trimesh.Trimesh): np.ndarray | None: 形状为 (N, 7) 的数组 [x, y, z, nx, ny, nz, sdf], 如果采样或计算失败则返回 None。 """ - grid_size = 2**5 + 1 + grid_size = 2**4 + 1 start = -0.5 end = 0.5 x = np.linspace(start, end, grid_size) diff --git a/brep2sdf/networks/decoder.py b/brep2sdf/networks/decoder.py index 654fdd9..20c47bf 100644 --- a/brep2sdf/networks/decoder.py +++ b/brep2sdf/networks/decoder.py @@ -54,12 +54,8 @@ class Decoder(nn.Module): self.norm_layers.append(nn.LayerNorm(dim)) if geometric_init: - self.activation = nn.Sequential( - nn.LayerNorm(out_dim), # 添加层归一化 - nn.Softplus(beta=beta) - ) if beta > 0: - self.activation = nn.SiLU() + self.activation = nn.Softplus(beta=beta) # vanilla relu else: self.activation = nn.ReLU() @@ -68,7 +64,7 @@ class Decoder(nn.Module): self.activation = Sine() self.final_activation = nn.Tanh() - def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor: + def forward_patch(self, feature_matrix: torch.Tensor) -> torch.Tensor: ''' :param feature_matrix: 形状为 (B, P, D) 的特征矩阵 B: 批大小 @@ -98,7 +94,7 @@ class Decoder(nn.Module): return f_i @torch.jit.export - def forward_training_volumes(self, feature_matrix: torch.Tensor) -> torch.Tensor: + def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor: ''' :param feature_matrix: 形状为(S, D) 的特征矩阵 S: 采样数量 diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 49e4803..e939e1a 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn import numpy as np import time + from .octree import OctreeNode from .feature_volume import PatchFeatureVolume,SimpleFeatureEncoder from brep2sdf.utils.logger import logger @@ -89,38 +90,43 @@ class Encoder(nn.Module): volume_indices_mask: 关联的volume索引矩阵 (B, P) 返回: - 特征张量 (B, P, D) + # 获取前两个有效特征索引 + # 注意:当某行True数量小于2时: + # 1. 如果只有1个True,会重复获取该特征两次 + # 2. 如果没有True,会获取背景特征两次(因为mask最后补充了一列True) + 特征张量 (B, 2, D) """ batch_size, num_volumes = volume_indices_mask.shape - all_features = torch.zeros(batch_size, num_volumes, self.feature_dim, + all_features = torch.zeros(batch_size, num_volumes+1, self.feature_dim, device=query_points.device) background_features = self.background.forward(query_points) # (B, D) - start_time = time.time() - # 创建 CUDA 流 - streams = [torch.cuda.Stream() for _ in range(len(self.feature_volumes))] - features_list = [None] * len(self.feature_volumes) - # 并行计算 + # 遍历每个volume索引 for vol_id, volume in enumerate(self.feature_volumes): mask = volume_indices_mask[:, vol_id].squeeze() - if not mask.any(): - continue - with torch.cuda.stream(streams[vol_id]): - features = volume(query_points[mask]) - features_list[vol_id] = (mask, features) - - # 同步流 - torch.cuda.synchronize() - - # 写入结果 - for vol_id, item in enumerate(features_list): - if item is None: - continue - mask, features = item - all_features[mask, vol_id] = 0.1 * background_features[mask] + 0.9 * features - end_time = time.time() - logger.debug(f"duration:{end_time-start_time}") - return all_features + #logger.debug(f"mask:{mask},shape:{mask.shape},mask.any():{mask.any()}") + if mask.any(): + # 获取对应volume的特征 (M, D) + features = volume.forward(query_points[mask]) + all_features[mask, vol_id] = features + # 最后一维度作为背景场 + all_features[:,num_volumes] = background_features + #all_features[:, :] = background_features.unsqueeze(1) + features = torch.zeros(batch_size, 2, self.feature_dim, + device=query_points.device) + + # mask从 volume_indices_mask(B,P) 变成 (B,P+1) ,True 补充 + mask = torch.cat([ + volume_indices_mask, + torch.ones(batch_size, 1, dtype=torch.bool, device=volume_indices_mask.device) + ], dim=1) + # 对于每个样本,取前两个非零特征, 如果没有 + + _, valid_indices = torch.topk(mask.float(), 2, dim=1) # (B, 2) + # 使用gather获取特征 + features = all_features.gather(1, valid_indices.unsqueeze(-1).expand(-1, -1, self.feature_dim)) + + return features def forward_background(self, query_points: torch.Tensor) -> torch.Tensor: """ @@ -146,10 +152,10 @@ class Encoder(nn.Module): """ # 获取 patch 特征 patch_features = self.feature_volumes[patch_id].forward(surf_points) - background_features = self.background.forward(surf_points) # (B, D) + #background_features = self.background.forward(surf_points) # (B, D) #dot = make_dot(patch_features, params=dict(self.feature_volumes.named_parameters())) #dot.render("feature_extraction", format="png") # 将计算图保存为 PDF 文件 - return 0.1 * background_features + 0.9 * patch_features + return patch_features def to(self, device): super().to(device) @@ -178,4 +184,4 @@ class Encoder(nn.Module): for param in volume.parameters(): param.requires_grad = True for param in self.background.parameters(): - param.requires_grad = True \ No newline at end of file + param.requires_grad = True diff --git a/brep2sdf/networks/feature_volume.py b/brep2sdf/networks/feature_volume.py index 8597a85..2364739 100644 --- a/brep2sdf/networks/feature_volume.py +++ b/brep2sdf/networks/feature_volume.py @@ -45,36 +45,53 @@ class PatchFeatureVolume(nn.Module): return self._batched_trilinear(normalized) def _batched_trilinear(self, normalized: torch.Tensor) -> torch.Tensor: - """批量处理的三线性插值""" - # 计算8个顶点的权重 - uvw = normalized * (self.resolution - 1) - indices = torch.floor(uvw).long() # (B,3) - weights = uvw - indices.float() # (B,3) + """ + 修复后的批量三线性插值 + Args: + normalized (Tensor): [B, 3],归一化坐标范围 [0, 1] + Returns: + Tensor: [B, feature_dim] + """ + B = normalized.shape[0] + device = normalized.device - # 计算8个顶点的权重组合 (B,8) - weights = torch.stack([ - (1 - weights[...,0]) * (1 - weights[...,1]) * (1 - weights[...,2]), - (1 - weights[...,0]) * (1 - weights[...,1]) * weights[...,2], - (1 - weights[...,0]) * weights[...,1] * (1 - weights[...,2]), - (1 - weights[...,0]) * weights[...,1] * weights[...,2], - weights[...,0] * (1 - weights[...,1]) * (1 - weights[...,2]), - weights[...,0] * (1 - weights[...,1]) * weights[...,2], - weights[...,0] * weights[...,1] * (1 - weights[...,2]), - weights[...,0] * weights[...,1] * weights[...,2], - ], dim=-1) # (B,8) + # 将归一化坐标映射到网格索引范围 [0, resolution - 1] + uvw = normalized * (self.resolution - 1) + indices = torch.floor(uvw).long() + weights = uvw - indices.float() - # 获取8个顶点的特征 (B,8,D) - indices = indices.unsqueeze(1).expand(-1,8,-1) + torch.tensor([ - [0,0,0], [0,0,1], [0,1,0], [0,1,1], - [1,0,0], [1,0,1], [1,1,0], [1,1,1] - ], device=indices.device) - indices = torch.clamp(indices, 0, self.resolution-1) + # 确保所有维度对齐 + indices = torch.clamp(indices, 0, self.resolution - 2) # 改为resolution-2防止越界 - features = self.feature_volume[indices[...,0], indices[...,1], indices[...,2]] # (B,8,D) - - # 加权求和 (B,D) - return torch.einsum('bnd,bn->bd', features, weights) + # 获取8个顶点的坐标 + x, y, z = indices.unbind(dim=-1) + w_x, w_y, w_z = weights.unbind(dim=-1) + + # 确保权重维度正确 + w_z = w_z.unsqueeze(-1) + w_y = w_y.unsqueeze(-1) + w_x = w_x.unsqueeze(-1) + + # 获取特征值 + c00 = self.feature_volume[x, y, z ] + c01 = self.feature_volume[x, y, z + 1] + c10 = self.feature_volume[x, y + 1, z ] + c11 = self.feature_volume[x, y + 1, z + 1] + c20 = self.feature_volume[x + 1, y, z ] + c21 = self.feature_volume[x + 1, y, z + 1] + c30 = self.feature_volume[x + 1, y + 1, z ] + c31 = self.feature_volume[x + 1, y + 1, z + 1] + + # 插值计算 + c0 = c00 * (1 - w_z) + c01 * w_z + c1 = c10 * (1 - w_z) + c11 * w_z + c2 = c20 * (1 - w_z) + c21 * w_z + c3 = c30 * (1 - w_z) + c31 * w_z + + c_top = c0 * (1 - w_y) + c1 * w_y + c_bot = c2 * (1 - w_y) + c3 * w_y + return c_top * (1 - w_x) + c_bot * w_x class SimpleFeatureEncoder(nn.Module): def __init__(self, input_dim=3, feature_dim=64): diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index 314f2c9..89c4a79 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -79,36 +79,25 @@ class Net(nn.Module): dims_sdf=[decoder_hidden_dim] * decoder_num_layers, #skip_in=(3,), geometric_init=True, - beta=5 + beta=100 ) #self.csg_combiner = CSGCombiner(flag_convex=True) - def process_sdf(self,f_i, face_indices_mask, operator): + def process_sdf(self, f_i, operator): output = f_i[:,0] - # 提取有效值并填充到固定大小 (B, max_patches) - padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches) - masked_f_i = torch.where(face_indices_mask, f_i, torch.tensor(float('inf'), device=f_i.device)) # 将无效值设置为 inf - - # 对每个样本取前 max_patches 个有效值 (B, max_patches) - valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False) # 提取前两个有效值 - - # 填充到固定大小 (B, max_patches) - padded_f_i[:, :2] = valid_values # (B, max_patches) - # 找到需要组合的行 mask_concave = (operator == 0) mask_convex = (operator == 1) - # 对 operator == 0 的样本取最大值 + # 对 operator == 0 的样本取最小值 if mask_concave.any(): - output[mask_concave] = torch.min(padded_f_i[mask_concave], dim=1).values + output[mask_concave] = torch.min(f_i[mask_concave], dim=1).values - # 对 operator == 1 的样本取最小值 + # 对 operator == 1 的样本取最大值 if mask_convex.any(): - output[mask_convex] = torch.max(padded_f_i[mask_convex], dim=1).values + output[mask_convex] = torch.max(f_i[mask_convex], dim=1).values - #logger.gpu_memory_stats("combine后") return output @torch.jit.export @@ -125,18 +114,7 @@ class Net(nn.Module): #logger.debug("step octree") _,face_indices_mask,operator = self.octree_module.forward(query_points) #logger.debug("step encode") - # 编码 - feature_vectors = self.encoder.forward(query_points,face_indices_mask) - #print("feature_vector:", feature_vectors.shape) - # 解码 - #logger.debug("step decode") - #logger.gpu_memory_stats("encoder farward后") - f_i = self.decoder(feature_vectors) # (B, P) - #logger.gpu_memory_stats("decoder farward后") - - #logger.debug("step combine") - return f_i[:,0] - return self.process_sdf(f_i, face_indices_mask, operator) + return self.forward_without_octree(query_points,face_indices_mask,operator) @torch.jit.export def forward_background(self, query_points): @@ -152,31 +130,36 @@ class Net(nn.Module): # 编码 feature_vectors = self.encoder.forward_background(query_points) # 解码 - h = self.decoder.forward_training_volumes(feature_vectors) # (B, D) + h = self.decoder.forward(feature_vectors) # (B, D) return h - @torch.jit.ignore - def forward_without_octree(self, query_points,face_indices_mask,operator): + @torch.jit.export + def forward_without_octree(self, query_points, face_indices_mask, operator): """ 前向传播 参数: - query_point: 查询点的位置坐标 + query_points: 查询点的位置坐标 (B, 3) + face_indices_mask: 面索引掩码 (B, P) + operator: 操作符 (B,) + 返回: - output: 解码后的输出结果 + output: 解码后的SDF值 (B,) """ - # 批量查询所有点的索引和bbox - #logger.debug("step encode") # 编码 - feature_vectors = self.encoder(query_points,face_indices_mask) + feature_vectors = self.encoder(query_points,face_indices_mask) # (B, 2, D) + feature_dim = feature_vectors.size(-1) # 获取特征维度 + flatten_feature_vectors = feature_vectors.reshape(-1, feature_dim) # (B*2, D) #print("feature_vector:", feature_vectors.shape) # 解码 - f_i = self.decoder(feature_vectors) # (B, P) + f_i = self.decoder(flatten_feature_vectors) # (B*2, ) + B = feature_vectors.size(0) # 获取batch size + f_i = f_i.reshape(B, 2) # 将输出reshape为(B, 2) #logger.gpu_memory_stats("decoder farward后") #logger.debug("step combine") - return f_i[:,0] - return self.process_sdf(f_i, face_indices_mask, operator) + + return self.process_sdf(f_i, operator) @torch.jit.ignore def forward_training_volumes(self, surf_points, patch_id:int): @@ -186,7 +169,7 @@ class Net(nn.Module): return (P, S) """ feature_mat = self.encoder.forward_training_volumes(surf_points,patch_id) - f_i = self.decoder.forward_training_volumes(feature_mat) + f_i = self.decoder.forward(feature_mat) return f_i.squeeze() diff --git a/brep2sdf/scripts/farward_speed.py b/brep2sdf/scripts/farward_speed.py new file mode 100644 index 0000000..33dcce8 --- /dev/null +++ b/brep2sdf/scripts/farward_speed.py @@ -0,0 +1,81 @@ +import re + +for_log_data = """ +2025-05-21 17:35:04,438 | DEBUG  | encoder.py:forward:108 - duration:0.02291393280029297 +2025-05-21 17:35:04,487 | DEBUG  | encoder.py:forward:108 - duration:0.013659954071044922 +2025-05-21 17:35:05,096 | DEBUG  | encoder.py:forward:108 - duration:0.013151884078979492 +2025-05-21 17:35:05,128 | DEBUG  | encoder.py:forward:108 - duration:0.012245893478393555 +2025-05-21 17:35:05,667 | DEBUG  | encoder.py:forward:108 - duration:0.012324810028076172 +2025-05-21 17:35:05,698 | DEBUG  | encoder.py:forward:108 - duration:0.011858940124511719 +2025-05-21 17:35:06,247 | DEBUG  | encoder.py:forward:108 - duration:0.013593196868896484 +2025-05-21 17:35:06,278 | DEBUG  | encoder.py:forward:108 - duration:0.012105226516723633 +2025-05-21 17:35:06,829 | DEBUG  | encoder.py:forward:108 - duration:0.012081146240234375 +2025-05-21 17:35:06,859 | DEBUG  | encoder.py:forward:108 - duration:0.011334419250488281 +2025-05-21 17:35:07,404 | DEBUG  | encoder.py:forward:108 - duration:0.013489246368408203 +2025-05-21 17:35:07,436 | DEBUG  | encoder.py:forward:108 - duration:0.01230931282043457 +2025-05-21 17:35:07,983 | DEBUG  | encoder.py:forward:108 - duration:0.01315164566040039 +2025-05-21 17:35:08,015 | DEBUG  | encoder.py:forward:108 - duration:0.012539148330688477 +2025-05-21 17:35:08,569 | DEBUG  | encoder.py:forward:108 - duration:0.014146566390991211 +2025-05-21 17:35:08,602 | DEBUG  | encoder.py:forward:108 - duration:0.013015508651733398 +2025-05-21 17:35:09,156 | DEBUG  | encoder.py:forward:108 - duration:0.01263570785522461 +2025-05-21 17:35:09,186 | DEBUG  | encoder.py:forward:108 - duration:0.011255264282226562 +2025-05-21 17:35:09,722 | DEBUG  | encoder.py:forward:108 - duration:0.014206647872924805 +2025-05-21 17:35:09,754 | DEBUG  | encoder.py:forward:108 - duration:0.012360095977783203 +2025-05-21 17:35:10,307 | DEBUG  | encoder.py:forward:108 - duration:0.013350963592529297 +2025-05-21 17:35:10,339 | DEBUG  | encoder.py:forward:108 - duration:0.012225151062011719 +2025-05-21 17:35:10,894 | DEBUG  | encoder.py:forward:108 - duration:0.014019250869750977 +2025-05-21 17:35:10,925 | DEBUG  | encoder.py:forward:108 - duration:0.012645483016967773 +2025-05-21 17:35:11,477 | DEBUG  | encoder.py:forward:108 - duration:0.010942935943603516 +2025-05-21 17:35:11,494 | DEBUG  | encoder.py:forward:108 - duration:0.010617733001708984 +""" + +parallel_log_data = """ +2025-05-21 17:25:30,716 | DEBUG  | encoder.py:forward:122 - duration:0.014799833297729492 +2025-05-21 17:25:30,748 | DEBUG  | encoder.py:forward:122 - duration:0.013928413391113281 +2025-05-21 17:25:31,318 | DEBUG  | encoder.py:forward:122 - duration:0.020897626876831055 +2025-05-21 17:25:31,352 | DEBUG  | encoder.py:forward:122 - duration:0.013567924499511719 +2025-05-21 17:25:31,929 | DEBUG  | encoder.py:forward:122 - duration:0.020887374877929688 +2025-05-21 17:25:31,965 | DEBUG  | encoder.py:forward:122 - duration:0.014947652816772461 +2025-05-21 17:25:32,550 | DEBUG  | encoder.py:forward:122 - duration:0.02316737174987793 +2025-05-21 17:25:32,586 | DEBUG  | encoder.py:forward:122 - duration:0.01513051986694336 +2025-05-21 17:25:33,172 | DEBUG  | encoder.py:forward:122 - duration:0.021285295486450195 +2025-05-21 17:25:33,207 | DEBUG  | encoder.py:forward:122 - duration:0.015576839447021484 +2025-05-21 17:25:33,790 | DEBUG  | encoder.py:forward:122 - duration:0.02099466323852539 +2025-05-21 17:25:33,826 | DEBUG  | encoder.py:forward:122 - duration:0.015471696853637695 +2025-05-21 17:25:34,406 | DEBUG  | encoder.py:forward:122 - duration:0.021028518676757812 +2025-05-21 17:25:34,441 | DEBUG  | encoder.py:forward:122 - duration:0.015815019607543945 +2025-05-21 17:25:35,034 | DEBUG  | encoder.py:forward:122 - duration:0.020988941192626953 +2025-05-21 17:25:35,070 | DEBUG  | encoder.py:forward:122 - duration:0.01592278480529785 +2025-05-21 17:25:35,662 | DEBUG  | encoder.py:forward:122 - duration:0.019669532775878906 +2025-05-21 17:25:35,698 | DEBUG  | encoder.py:forward:122 - duration:0.015323638916015625 +2025-05-21 17:25:36,276 | DEBUG  | encoder.py:forward:122 - duration:0.02336907386779785 +2025-05-21 17:25:36,311 | DEBUG  | encoder.py:forward:122 - duration:0.015668869018554688 +2025-05-21 17:25:36,896 | DEBUG  | encoder.py:forward:122 - duration:0.022051572799682617 +2025-05-21 17:25:36,932 | DEBUG  | encoder.py:forward:122 - duration:0.015897512435913086 +2025-05-21 17:25:37,526 | DEBUG  | encoder.py:forward:122 - duration:0.020981311798095703 +2025-05-21 17:25:37,560 | DEBUG  | encoder.py:forward:122 - duration:0.015113353729248047 +2025-05-21 17:25:38,137 | DEBUG  | encoder.py:forward:122 - duration:0.018566131591796875 +2025-05-21 17:25:38,157 | DEBUG  | encoder.py:forward:122 - duration:0.013988733291625977 +""" +# 计算 duration 平均值 + +def run(log_data): + # 使用正则表达式提取所有的 duration 数值 + durations = re.findall(r'duration:(\d+\.\d+)', log_data) + + # 转换为浮点数列表 + durations = [float(d) for d in durations] + + # 计算平均值 + average_duration = sum(durations) / len(durations) if durations else 0 + + # 输出结果 + print(f"共找到 {len(durations)} 个 duration") + print(f"平均 duration: {average_duration:.6f} 秒") + return average_duration + + +speed1 = run(for_log_data) +speed2 = run(parallel_log_data) +print(f"for speed: {speed1:.6f} 秒") +print(f"parallel speed: {speed2:.6f} 秒") \ No newline at end of file diff --git a/brep2sdf/scripts/npz2points.py b/brep2sdf/scripts/npz2points.py index ef27e58..ebb6318 100644 --- a/brep2sdf/scripts/npz2points.py +++ b/brep2sdf/scripts/npz2points.py @@ -11,10 +11,10 @@ def load_brep_file(brep_path): if __name__ == "__main__": - data=load_brep_file("/home/wch/brep2sdf/data/output_data/00000003.xyz") - surfs =data["train_surf_ncs"] + data=load_brep_file("/home/wch/brep2sdf/data/output_data/00000031.xyz") + surfs =data["sampled_points_normals_sdf"] print(surfs) - with open("0003_t.xyz","w") as f: + with open("0031_t.xyz","w") as f: for point in surfs: #f.write(f"{point[0]} {point[1]} {point[2]}\n") f.write(f"{point[0]} {point[1]} {point[2]} {point[3]} {point[4]} {point[5]}\n") diff --git a/brep2sdf/test.py b/brep2sdf/test.py index 17571b0..d21beda 100644 --- a/brep2sdf/test.py +++ b/brep2sdf/test.py @@ -164,7 +164,7 @@ def sample_grid(trimesh_mesh_ncs: trimesh.Trimesh): np.ndarray | None: 形状为 (N, 7) 的数组 [x, y, z, nx, ny, nz, sdf], 如果采样或计算失败则返回 None。 """ - grid_size = 2**5 + 1 + grid_size = 2**4 + 1 start = -1 end = 1 x = np.linspace(start, end, grid_size) @@ -199,6 +199,7 @@ def test2(obj_file): # 将点坐标和SDF值转换为网格格式 grid_size = int(np.cbrt(len(points))) # 假设采样点是立方体网格 + print(f"grid size:{grid_size}") sdf_grid = sdf_values.reshape((grid_size, grid_size, grid_size)) # 使用Marching Cubes提取零表面 @@ -243,5 +244,5 @@ def main(): if __name__ == "__main__": #main() #test() - test2("/home/wch/brep2sdf/data/gt_mesh/00000003.obj") + test2("/home/wch/brep2sdf/data/gt_mesh/00000031.obj") # python test.py -i /home/wch/brep2sdf/data/gt_mesh/00000003.obj -o output.ply --depth 6 --box_size 2.0 --method MC \ No newline at end of file diff --git a/brep2sdf/train.py b/brep2sdf/train.py index e05fa7c..2bd2aa5 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -604,7 +604,7 @@ class Trainer: self.model.train() total_loss = 0.0 step = 0 # 如果你的训练是分批次的,这里应该用批次索引 - batch_size = 4096*5 # 设置合适的batch大小 + batch_size = 50000 # 设置合适的batch大小 # 数据处理 # manfld @@ -640,7 +640,7 @@ class Trainer: _nonmnfld_face_indices_mask = self.cached_train_data["nonmnfld_face_indices_mask"] _nonmnfld_operator = self.cached_train_data["nonmnfld_operator"] - logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) + #logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf)) @@ -704,7 +704,7 @@ class Trainer: mnfld_pred, nonmnfld_pred, psdf - ) + ) #logger.gpu_memory_stats("计算损失后") else: loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) @@ -747,10 +747,11 @@ class Trainer: # 记录训练进度 (只记录有效的损失) - logger.info(f'Train Epoch: {epoch:4d}]\t' - f'Loss: {total_loss:.6f}') - if loss_details: logger.info(f"Loss Details: {loss_details}") - #self.validate(epoch,total_loss) + if epoch % 10 == 0: + logger.info(f'Train Epoch: {epoch:4d}]\t' + f'Loss: {total_loss:.6f}') + if loss_details: logger.info(f"Loss Details: {loss_details}") + self.validate(epoch,total_loss) return total_loss # 对于单批次训练,直接返回当前损失 @@ -932,8 +933,8 @@ class Trainer: #stage 3 self.scheduler.reset() - self.model.freeze_stage2() - #self.model.unfreeze() + #self.model.freeze_stage2() + self.model.unfreeze() for epoch in range(cur_epoch + 1, max_stage2_epoch + self.config.train.num_epochs3 + 1): # 训练一个epoch train_loss = self.train_epoch_stage3(epoch)