diff --git a/brep2sdf/config/default_config.py b/brep2sdf/config/default_config.py index d8ed5b9..1018b45 100644 --- a/brep2sdf/config/default_config.py +++ b/brep2sdf/config/default_config.py @@ -48,7 +48,7 @@ class TrainConfig: # 基本训练参数 batch_size: int = 8 num_workers: int = 4 - num_epochs: int = 100 + num_epochs: int = 1000 learning_rate: float = 0.001 min_lr: float = 1e-5 weight_decay: float = 0.01 @@ -90,7 +90,7 @@ class LogConfig: # 本地日志 log_dir: str = '/home/wch/brep2sdf/logs' # 日志保存目录 log_level: str = 'INFO' # 日志级别 - console_level: str = 'DEBUG' # 控制台日志级别 + console_level: str = 'INFO' # 控制台日志级别 file_level: str = 'DEBUG' # 文件日志级别 @dataclass diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py index 1a1bf2b..088e9b0 100644 --- a/brep2sdf/data/utils.py +++ b/brep2sdf/data/utils.py @@ -10,6 +10,9 @@ from OCC.Core.TopoDS import topods, TopoDS_Vertex # 拓扑数据结构 import numpy as np from scipy.spatial import cKDTree +from torch.nn.utils.rnn import pad_sequence +import torch + from brep2sdf.utils.logger import logger def load_step(step_path): @@ -35,7 +38,22 @@ def get_bbox(shape, subshape): xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get() return np.array([xmin, ymin, zmin, xmax, ymax, zmax]) - +def process_surf_ncs_with_dynamic_padding(surf_ncs: np.ndarray) -> torch.Tensor: + """ + 使用 pad_sequence 动态填充 surf_ncs。 + + 参数: + surf_ncs: 形状为 (N,) 的 np.ndarray(dtype=object),每个元素是形状为 (M, 3) 的 float32 数组。 + + 返回: + padded_tensor: 形状为 (N, M_max, 3) 的张量,其中 M_max 是最长子数组的长度。 + """ + # 转换为张量列表 + tensor_list = [torch.tensor(arr, dtype=torch.float32) for arr in surf_ncs] + + # 动态填充 + padded_tensor = pad_sequence(tensor_list, batch_first=True, padding_value=float('inf')) + return padded_tensor def normalize(surfs, edges, corners): diff --git a/brep2sdf/networks/decoder.py b/brep2sdf/networks/decoder.py index 9f9ea57..7e8d142 100644 --- a/brep2sdf/networks/decoder.py +++ b/brep2sdf/networks/decoder.py @@ -1,42 +1,219 @@ import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor +import torch +import numpy as np +from typing import Tuple, List, Union +from brep2sdf.utils.logger import logger +class Sine(nn.Module): + def __init(self): + super().__init__() + + def forward(self, input): + # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30 + return torch.sin(30 * input) class Decoder(nn.Module): - def __init__(self, - input_dim: int, - output_dim: int, - hidden_dim: int = 256) : - """ - 最简单的Decoder实现 - - 参数: - input_dim: 输入维度 - output_dim: 输出维度 - hidden_dim: 隐藏层维度 (默认: 256) - """ + def __init__( + self, + d_in: int, + dims_sdf: List[int], + skip_in: Tuple[int, ...] = (), + flag_convex: bool = True, + geometric_init: bool = True, + radius_init: float = 1, + beta: float = 100, + ) -> None: super().__init__() - - # 三层全连接网络 - self.fc1 = nn.Linear(input_dim, hidden_dim) - self.fc2 = nn.Linear(hidden_dim, hidden_dim) - self.fc3 = nn.Linear(hidden_dim, output_dim) + + self.flag_convex = flag_convex + self.skip_in = skip_in + + dims_sdf = [d_in] + dims_sdf + [1] #[self.n_branch] + self.sdf_layers = len(dims_sdf) + for layer in range(0, len(dims_sdf) - 1): + if layer + 1 in skip_in: + out_dim = dims_sdf[layer + 1] - d_in + else: + out_dim = dims_sdf[layer + 1] + lin = nn.Linear(dims_sdf[layer], out_dim) + if geometric_init: + if layer == self.sdf_layers - 2: + torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims_sdf[layer]), std=0.00001) + torch.nn.init.constant_(lin.bias, -radius_init) + else: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + setattr(self, "sdf_"+str(layer), lin) + if geometric_init: + if beta > 0: + self.activation = nn.Softplus(beta=beta) + # vanilla relu + else: + self.activation = nn.ReLU() + else: + #siren + self.activation = Sine() + self.final_activation = nn.ReLU() + + # composite f_i to h - def forward(self, x: Tensor) -> Tensor: - """ - 前向传播 + + + def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor: + ''' + :param feature_matrix: 形状为 (B, P, D) 的特征矩阵 + B: 批大小 + P: patch volume数量 + D: 特征维度 + :return: + f_i: 各patch的SDF值 (B, P) + ''' + B, P, D = feature_matrix.shape + + # 展平处理 (B*P, D) + x = feature_matrix.view(-1, D) + + for layer in range(0, self.sdf_layers - 1): + lin = getattr(self, "sdf_" + str(layer)) + if layer in self.skip_in: + x = torch.cat([x, x], -1) / np.sqrt(2) # Fix undefined 'input' + + x = lin(x) + if layer < self.sdf_layers - 2: + x = self.activation(x) + output_value = x #all f_i + + # 恢复维度 (B, P) + f_i = output_value.view(B, P) + + return f_i + + +# 一个基础情形: 输入 fi 形状[P] 和 csg tree,凹凸组合输出h +#注意考虑如何批量处理 (B, P) 和 [csg tree] +class CSGCombiner: + def __init__(self, flag_convex: bool = True, rho: float = 0.05): + self.flag_convex = flag_convex + self.rho = rho + + def forward(self, f_i: torch.Tensor, csg_tree) -> torch.Tensor: + ''' + :param f_i: 形状为 (B, P) 的各patch SDF值 + :param csg_tree: CSG树结构 + :return: 组合后的整体SDF (B,) + ''' + logger.info("\n".join(f"第{i}个csg: {t}" for i,t in enumerate(csg_tree))) + B = f_i.shape[0] + results = [] + for i in range(B): + # 处理每个样本的CSG组合 + h = self.nested_cvx_output_soft_blend( + f_i[i].unsqueeze(0), + csg_tree, + self.flag_convex + ) + results.append(h) + + return torch.cat(results, dim=0).squeeze(1) # 从(B,1)变为(B,) + + def nested_cvx_output_soft_blend( + self, + value_matrix: torch.Tensor, + list_operation: List[Union[int, List]], + cvx_flag: bool = True + ) -> torch.Tensor: + list_value = [] + for v in list_operation: + if not isinstance(v, list): + list_value.append(v) + + op_mat = torch.zeros(value_matrix.shape[1], len(list_value), + device=value_matrix.device) + for i in range(len(list_value)): + op_mat[list_value[i]][i] = 1.0 + + mat_mul = torch.matmul(value_matrix, op_mat) + if len(list_operation) == len(list_value): + return self.max_soft_blend(mat_mul, self.rho) if cvx_flag \ + else self.min_soft_blend(mat_mul, self.rho) - 参数: - x: 输入张量 - 返回: - 输出张量 - """ - # 第一层 - h = F.relu(self.fc1(x)) - # 第二层 - h = F.relu(self.fc2(h)) - # 输出层 - out = self.fc3(h) + list_output = [mat_mul] + for v in list_operation: + if isinstance(v, list): + list_output.append( + self.nested_cvx_output_soft_blend( + value_matrix, v, not cvx_flag + ) + ) - return out \ No newline at end of file + return self.max_soft_blend(torch.cat(list_output, 1), self.rho) if cvx_flag \ + else self.min_soft_blend(torch.cat(list_output, 1), self.rho) + + def min_soft_blend(self, mat, rho): + res = mat[:,0] + for i in range(1, mat.shape[1]): + srho = res * res + mat[:,i] * mat[:,i] - rho * rho + res = res + mat[:,i] - torch.sqrt(res * res + mat[:,i] * mat[:,i] + 1.0/(8 * rho * rho) * srho * (srho - srho.abs())) + return res.unsqueeze(1) + + def max_soft_blend(self, mat, rho): + res = mat[:,0] + for i in range(1, mat.shape[1]): + srho = res * res + mat[:,i] * mat[:,i] - rho * rho + res = res + mat[:,i] + torch.sqrt(res * res + mat[:,i] * mat[:,i] + 1.0/(8 * rho * rho) * srho * (srho - srho.abs())) + return res.unsqueeze(1) + + +def test_csg_combiner(): + # 测试数据 (B=3, P=5) + f_i = torch.tensor([ + [1.0, 2.0, 3.0, 4.0, 5.0], + [0.5, 1.5, 2.5, 3.5, 4.5], + [-1.0, 0.0, 1.0, 2.0, 3.0] + ]) + + # 每个样本使用不同的CSG树结构 + csg_trees = [ + [0, [1, 2]], # 使用索引0,1,2 + [[0, 1], 3], # 使用索引0,1,3 + [0, 1, [2, 4]] # 使用索引0,1,2,4 + ] + + # 验证所有索引都有效 + P = f_i.shape[1] + for i, tree in enumerate(csg_trees): + def check_indices(node): + if isinstance(node, list): + for n in node: + check_indices(n) + else: + assert node < P, f"样本{i}的树包含无效索引{node},P={P}" + check_indices(tree) + + print("Input SDF values:") + print(f_i) + print("\nCSG Trees:") + for i, tree in enumerate(csg_trees): + print(f"Sample {i}: {tree}") + + # 测试凸组合 + print("\nTesting convex combination:") + combiner_convex = CSGCombiner(flag_convex=True) + h_convex = combiner_convex.forward(f_i, csg_trees) + print("Results:", h_convex) + + # 测试凹组合 + print("\nTesting concave combination:") + combiner_concave = CSGCombiner(flag_convex=False) + h_concave = combiner_concave.forward(f_i, csg_trees) + print("Results:", h_concave) + + # 测试不同rho值的软混合 + print("\nTesting soft blends:") + for rho in [0.01, 0.1, 0.5]: + combiner_soft = CSGCombiner(flag_convex=True, rho=rho) + h_soft = combiner_soft.forward(f_i, csg_trees) + print(f"rho={rho}:", h_soft) + +if __name__ == "__main__": + test_csg_combiner() \ No newline at end of file diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 6d3ddcf..03d2c53 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -2,72 +2,89 @@ import torch import torch.nn as nn from .octree import OctreeNode +from .feature_volume import PatchFeatureVolume from brep2sdf.utils.logger import logger class Encoder(nn.Module): - def __init__(self, octree: OctreeNode, feature_dim: int = 32): + def __init__(self, volume_bboxs:torch.tensor, feature_dim: int = 32): """ 分离后的编码器,接收预构建的八叉树 参数: - octree: 预构建的八叉树结构 + volume_bboxs: 所有面片的边界框集合,形状为 (N, 2, 3) feature_dim: 特征维度 """ super().__init__() self.feature_dim = feature_dim - # 初始化叶子节点参数 - self.register_buffer('num_parameters', torch.tensor(0, dtype=torch.long)) - self._leaf_features = None # 将在_init_parameters中初始化 - self._init_parameters(octree) - - def _init_parameters(self,octree): - stack = [(octree, 0)] - param_count = 0 - - while stack: - node, _ = stack.pop() - if node._is_leaf: - param_count += 1 - else: - for child in node.child_nodes: - if child: stack.append((child, 0)) + # 批量计算所有bbox的分辨率 + resolutions = self._batch_calculate_resolution(volume_bboxs) + + # 初始化多个特征体积 + self.feature_volumes = nn.ModuleList([ + PatchFeatureVolume( + bbox=bbox, + resolution=int(resolutions[i]), + feature_dim=feature_dim + ) for i, bbox in enumerate(volume_bboxs) + ]) + print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}") + print(f"Model parameters memory usage: {sum(p.numel() * p.element_size() for p in self.parameters()) / 1024**2:.2f} MB") - # 初始化连续参数张量 - self._leaf_features = nn.Parameter( - torch.randn(param_count, 8, self.feature_dim)) + def _batch_calculate_resolution(self, bboxes: torch.Tensor) -> torch.Tensor: + """ + 批量计算归一化bboxes的分辨率 - # 重新遍历设置索引 - stack = [(octree, 0)] - index = 0 - while stack: - node, _ = stack.pop() - if node._is_leaf: - node.set_param_key(index) - index += 1 - else: - for child in node.child_nodes: - if child: stack.append((child, 0)) - self.num_parameters.fill_(index) + 参数: + bboxes: 归一化边界框张量,形状为 (N, 2, 3) + + 返回: + 分辨率张量 (N,) + """ + with torch.no_grad(): + # 计算每个bbox的对角线长度(归一化后范围约为0.0-1.732) + diagonals = torch.norm(bboxes[:,3:6] - bboxes[:,0:3], dim=1) + + # 根据归一化后的对角线长度调整分辨率 + resolutions = torch.zeros_like(diagonals, dtype=torch.long) + resolutions[diagonals > 1.0] = 16 # 大尺寸 + resolutions[(diagonals > 0.5) & (diagonals <= 1.0)] = 8 # 中等尺寸 + resolutions[diagonals <= 0.5] = 4 # 小尺寸 + + return resolutions - def forward(self, query_points: torch.Tensor,param_indices,bboxes) -> torch.Tensor: - batch_size = query_points.shape[0] - - # 批量获取特征 - unique_ids, inverse_ids = torch.unique(param_indices, return_inverse=True) - all_features = self._leaf_features[unique_ids] # (U, 8, D) - node_features = all_features[inverse_ids] # (B, 8, D) - # 启用混合精度和优化后的插值 - with torch.cuda.amp.autocast(): - features = self._optimized_trilinear( - query_points, - bboxes.detach(), - node_features - ) + def forward(self, query_points: torch.Tensor, volume_indices: torch.Tensor) -> torch.Tensor: + """ + 修改后的前向传播,返回所有关联volume的特征矩阵 + + 参数: + query_points: 查询点坐标 (B, 3) + volume_indices: 关联的volume索引矩阵 (B, K) - # 添加类型转换确保输出为float32 - return features.to(torch.float32) # 添加这行 + 返回: + 特征张量 (B, K, D) + """ + batch_size, num_volumes = volume_indices.shape + all_features = torch.zeros(batch_size, num_volumes, self.feature_dim, + device=query_points.device) + + # 遍历每个volume索引 + for k in range(num_volumes): + # 获取当前volume的索引 (B,) + current_indices = volume_indices[:, k] + + # 遍历所有存在的volume + for vol_id in range(len(self.feature_volumes)): + # 创建掩码 (B,) + mask = (current_indices == vol_id) + if mask.any(): + # 获取对应volume的特征 (M, D) + features = self.feature_volumes[vol_id](query_points[mask]) + all_features[mask, k] = features + + return all_features + def _optimized_trilinear(self, points, bboxes, features): """优化后的向量化三线性插值""" diff --git a/brep2sdf/networks/feature_volume.py b/brep2sdf/networks/feature_volume.py index 52caa21..7257a14 100644 --- a/brep2sdf/networks/feature_volume.py +++ b/brep2sdf/networks/feature_volume.py @@ -4,50 +4,72 @@ import torch import torch.nn as nn class PatchFeatureVolume(nn.Module): - def __init__(self, bbox:np, resolution=64, feature_dim=64): + def __init__(self, bbox: torch.Tensor, resolution=64, feature_dim=64, padding_ratio=0.05): super(PatchFeatureVolume, self).__init__() - self.bbox = bbox # 补丁的边界框 - self.resolution = resolution # 网格分辨率 - self.feature_dim = feature_dim # 特征向量维度 - + # 将输入bbox转换为[min, max]格式 + self.resolution = resolution + min_coords = bbox[:3] + max_coords = bbox[3:] + self.original_bbox = torch.stack([min_coords, max_coords]) + expanded_bbox = self._expand_bbox(min_coords, max_coords, padding_ratio) # 创建规则的三维网格 - x = torch.linspace(bbox[0][0], bbox[1][0], resolution) - y = torch.linspace(bbox[0][1], bbox[1][1], resolution) - z = torch.linspace(bbox[0][2], bbox[1][2], resolution) + x = torch.linspace(expanded_bbox[0][0], expanded_bbox[1][0], resolution) + y = torch.linspace(expanded_bbox[0][1], expanded_bbox[1][1], resolution) + z = torch.linspace(expanded_bbox[0][2], expanded_bbox[1][2], resolution) grid_x, grid_y, grid_z = torch.meshgrid(x, y, z) self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1)) - # 初始化特征向量,作为可训练参数 + # 初始化特征向量 self.feature_volume = nn.Parameter(torch.randn(resolution, resolution, resolution, feature_dim)) - def forward(self, query_points: List[Tuple[float, float, float]]): + def _expand_bbox(self, min_coords, max_coords, ratio): + # 扩展包围盒范围 + center = (min_coords + max_coords) / 2 + expanded_min = center - (center - min_coords) * (1 + ratio) + expanded_max = center + (max_coords - center) * (1 + ratio) + return torch.stack([expanded_min, expanded_max]) + + def forward(self, query_points: torch.Tensor) -> torch.Tensor: + """批量处理版本的三线性插值 + Args: + query_points: 形状为 (B, 3) 的查询点坐标 + Returns: + 形状为 (B, D) 的特征向量 """ - 根据查询点的位置,从补丁特征体积中获取插值后的特征向量。 + # 添加类型转换确保计算稳定性 + normalized = ((query_points - self.grid[0,0,0]) / + (self.grid[-1,-1,-1] - self.grid[0,0,0] + 1e-8)) # (B,3) - :param query_points: 查询点的位置坐标,形状为 (N, 3) - :return: 插值后的特征向量,形状为 (N, feature_dim) - """ - interpolated_features = torch.zeros(query_points.shape[0], self.feature_dim).to(self.feature_volume.device) - for i, point in enumerate(query_points): - interpolated_feature = self.trilinear_interpolation(point) - interpolated_features[i] = interpolated_feature - return interpolated_features - - def trilinear_interpolation(self, query_point): - """三线性插值""" - normalized_coords = ((query_point - torch.tensor(self.bbox[0]).to(self.grid.device)) / - (torch.tensor(self.bbox[1]).to(self.grid.device) - torch.tensor(self.bbox[0]).to(self.grid.device))) * (self.resolution - 1) - indices = torch.floor(normalized_coords).long() - weights = normalized_coords - indices.float() + # 向量化三线性插值 + return self._batched_trilinear(normalized) - interpolated_feature = torch.zeros(self.feature_dim).to(self.feature_volume.device) - for di in range(2): - for dj in range(2): - for dk in range(2): - weight = (weights[0] if di == 1 else 1 - weights[0]) * \ - (weights[1] if dj == 1 else 1 - weights[1]) * \ - (weights[2] if dk == 1 else 1 - weights[2]) - index = indices + torch.tensor([di, dj, dk]).to(indices.device) - index = torch.clamp(index, 0, self.resolution - 1) - interpolated_feature += weight * self.feature_volume[index[0], index[1], index[2]] - return interpolated_feature \ No newline at end of file + def _batched_trilinear(self, normalized: torch.Tensor) -> torch.Tensor: + """批量处理的三线性插值""" + # 计算8个顶点的权重 + uvw = normalized * (self.resolution - 1) + indices = torch.floor(uvw).long() # (B,3) + weights = uvw - indices.float() # (B,3) + + # 计算8个顶点的权重组合 (B,8) + weights = torch.stack([ + (1 - weights[...,0]) * (1 - weights[...,1]) * (1 - weights[...,2]), + (1 - weights[...,0]) * (1 - weights[...,1]) * weights[...,2], + (1 - weights[...,0]) * weights[...,1] * (1 - weights[...,2]), + (1 - weights[...,0]) * weights[...,1] * weights[...,2], + weights[...,0] * (1 - weights[...,1]) * (1 - weights[...,2]), + weights[...,0] * (1 - weights[...,1]) * weights[...,2], + weights[...,0] * weights[...,1] * (1 - weights[...,2]), + weights[...,0] * weights[...,1] * weights[...,2], + ], dim=-1) # (B,8) + + # 获取8个顶点的特征 (B,8,D) + indices = indices.unsqueeze(1).expand(-1,8,-1) + torch.tensor([ + [0,0,0], [0,0,1], [0,1,0], [0,1,1], + [1,0,0], [1,0,1], [1,1,0], [1,1,1] + ], device=indices.device) + indices = torch.clamp(indices, 0, self.resolution-1) + + features = self.feature_volume[indices[...,0], indices[...,1], indices[...,2]] # (B,8,D) + + # 加权求和 (B,D) + return torch.einsum('bnd,bn->bd', features, weights) \ No newline at end of file diff --git a/brep2sdf/networks/network.py b/brep2sdf/networks/network.py index 10c8f25..15e0c64 100644 --- a/brep2sdf/networks/network.py +++ b/brep2sdf/networks/network.py @@ -49,11 +49,13 @@ import torch import torch.nn as nn from torch.autograd import grad from .encoder import Encoder -from .decoder import Decoder +from .decoder import Decoder, CSGCombiner +from brep2sdf.utils.logger import logger class Net(nn.Module): def __init__(self, octree, + volume_bboxs, feature_dim=64, decoder_input_dim=64, decoder_output_dim=1, @@ -69,11 +71,18 @@ class Net(nn.Module): # 初始化 Encoder self.encoder = Encoder( feature_dim=feature_dim, - octree=octree + volume_bboxs= volume_bboxs ) # 初始化 Decoder - self.decoder = Decoder(input_dim=64, output_dim=1) + self.decoder = Decoder( + d_in=decoder_input_dim, + dims_sdf=[decoder_hidden_dim] * decoder_num_layers, + geometric_init=True, + beta=100 + ) + + self.csg_combiner = CSGCombiner(flag_convex=True) def forward(self, query_points): """ @@ -85,14 +94,16 @@ class Net(nn.Module): output: 解码后的输出结果 """ # 批量查询所有点的索引和bbox - param_indices,bboxes = self.octree_module.forward(query_points) - print("param_indices requires_grad:", param_indices.requires_grad) # 应该输出False - print("bboxes requires_grad:", bboxes.requires_grad) # 应该输出False + _,face_indices,csg_trees = self.octree_module.forward(query_points) # 编码 - feature_vector = self.encoder.forward(query_points,param_indices,bboxes) - print("feature_vector:", feature_vector.requires_grad) + feature_vectors = self.encoder.forward(query_points,face_indices) + #print("feature_vector:", feature_vectors.requires_grad) # 解码 - output = self.decoder(feature_vector) + logger.gpu_memory_stats("encoder farward后") + f_i = self.decoder(feature_vectors) + logger.gpu_memory_stats("decoder farward后") + output = self.csg_combiner.forward(f_i, csg_trees) + logger.gpu_memory_stats("combine后") return output diff --git a/brep2sdf/networks/octree.py b/brep2sdf/networks/octree.py index 4c8f3c9..18a46ac 100644 --- a/brep2sdf/networks/octree.py +++ b/brep2sdf/networks/octree.py @@ -4,12 +4,14 @@ import torch import torch.nn as nn import numpy as np +from brep2sdf.data.utils import process_surf_ncs_with_dynamic_padding from brep2sdf.networks.patch_graph import PatchGraph +from brep2sdf.utils.logger import logger -def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor: +def bbox_intersect_(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor: """判断两个轴对齐包围盒(AABB)是否相交 参数: @@ -28,8 +30,90 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor: # 向量化比较 return torch.all((max1 >= min2) & (max2 >= min1)) +def if_points_in_box(points: np.ndarray, bbox: torch.Tensor) -> bool: + """判断点是否在AABB包围盒内 + + 参数: + points: 形状为 (N, 3) 的数组,表示N个点的坐标 + bbox: 形状为 (6,) 的张量,表示AABB包围盒的坐标 + + 返回: + bool: 如果所有点都在包围盒内,返回True,否则返回False + """ + # 将 points 转换为 torch.Tensor + points_tensor = torch.tensor(points, dtype=torch.float32, device=bbox.device) + + # 提取min和max坐标 + min_coords = bbox[:3] + max_coords = bbox[3:] + #logger.debug(f"min_coords: {min_coords}, max_coords: {max_coords}") + # 向量化比较 + return torch.any((points_tensor >= min_coords) & (points_tensor <= max_coords)).item() + +def bbox_intersect( + surf_bboxes: torch.Tensor, + indices: torch.Tensor, + child_bboxes: torch.Tensor, + surf_points: torch.Tensor = None +) -> torch.Tensor: + ''' + args: + surf_bboxes: [B, 6] - 表示多个包围盒的张量,每个包围盒由其最小和最大坐标定义。 + indices: 选择bbox, [N], N <= B - 用于选择特定包围盒的索引张量。 + child_bboxes: [8, 6] - 一个包围盒被分成八个子包围盒后的结果。 + surf_points: [B, M_max, 3] - 每个包围盒对应的点云数据(可选)。 + return: + result_mask: [8, B] - 布尔掩码,表示每个子边界框与所有包围盒是否相交, + 且是否包含至少一个点(如果提供了点云)。 + ''' + # 初始化全为 False 的结果掩码 [8, B] + B = surf_bboxes.size(0) + result_mask = torch.zeros((8, B), dtype=torch.bool).to(surf_bboxes.device) + logger.debug(result_mask.shape) + logger.debug(indices.shape) + + # 提取选中的边界框 + selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6] + min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3] + min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3] + + logger.debug(selected_bboxes.shape) + # 计算子包围盒与选中包围盒的交集 + intersect_mask = torch.all( + (max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3] + (max2.unsqueeze(1) >= min1.unsqueeze(0)), # 形状为 [8, N, 3] + dim=-1 + ) # 最终形状为 [8, N] + + # 更新结果掩码中选中的部分 + result_mask[:, indices] = intersect_mask + + # 如果提供了点云,进一步检查点是否在子包围盒内 + if surf_points is not None: + # 提取选中的点云 + selected_points = surf_points[indices] # 形状为 [N, M_max, 3] + + # 将点云广播到子边界框的维度 + points_expanded = selected_points.unsqueeze(1) # 形状为 [N, 1, M_max, 3] + min2_expanded = min2.unsqueeze(0).unsqueeze(2) # 形状为 [1, 8, 1, 3] + max2_expanded = max2.unsqueeze(0).unsqueeze(2) # 形状为 [1, 8, 1, 3] + + # 判断点是否在子边界框内 + point_in_box_mask = ( + (points_expanded >= min2_expanded) & # 形状为 [N, 8, M_max, 3] + (points_expanded <= max2_expanded) # 形状为 [N, 8, M_max, 3] + ).all(dim=-1) # 最终形状为 [N, 8, M_max] + + # 检查每个子边界框是否包含至少一个点 + points_in_boxes_mask = point_in_box_mask.any(dim=-1).permute(1, 0) # 形状为 [8, N] + + # 合并交集条件和点云条件 + result_mask[:, indices] = result_mask[:, indices] & points_in_boxes_mask + logger.debug(result_mask.shape) + return result_mask + class OctreeNode(nn.Module): - def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,device=None): + def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,surf_ncs:np.ndarray = None,device=None): super().__init__() self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 改为普通张量属性 @@ -39,52 +123,43 @@ class OctreeNode(nn.Module): self.child_indices = None self.is_leaf_mask = None # 面片索引张量 - self.face_indices = torch.from_numpy(face_indices).to(self.device) + self.all_face_indices = torch.from_numpy(face_indices).to(self.device) self.surf_bbox = surf_bbox.to(self.device) if surf_bbox is not None else None - + self.surf_ncs = process_surf_ncs_with_dynamic_padding(surf_ncs).to(self.device) # PatchGraph作为普通属性 self.patch_graph = patch_graph.to(self.device) if patch_graph is not None else None self.max_depth = max_depth - # 参数键改为普通张量 - self.param_key = torch.tensor(-1, dtype=torch.long, device=self.device) self._is_leaf = True - # 删除所有register_buffer调用 - - @torch.jit.export - def set_param_key(self, k: int) -> None: - """设置参数键值 - - 参数: - k: 参数索引值 - """ - self.param_key.fill_(k) - @torch.jit.export def build_static_tree(self) -> None: """构建静态八叉树结构""" # 预计算所有可能的节点数量,确保结果为整数 total_nodes = int(sum(8**i for i in range(self.max_depth + 1))) + num_faces = self.all_face_indices.shape[0] # 初始化静态张量,使用整数列表作为形状参数 self.node_bboxes = torch.zeros([int(total_nodes), 6], device=self.device) self.parent_indices = torch.full([int(total_nodes)], -1, dtype=torch.long, device=self.device) self.child_indices = torch.full([int(total_nodes), 8], -1, dtype=torch.long, device=self.device) - self.is_leaf_mask = torch.zeros([int(total_nodes)], dtype=torch.bool, device=self.device) + self.face_indices_mask = torch.zeros([int(total_nodes),num_faces], dtype=torch.bool, device=self.device) # 1 代表有 + self.is_leaf_mask = torch.ones([int(total_nodes)], dtype=torch.bool, device=self.device) # 使用队列进行广度优先遍历 - queue = [(0, self.bbox, self.face_indices)] # (node_idx, bbox, face_indices) + queue = [(0, self.bbox, self.all_face_indices)] # (node_idx, bbox, face_indices) current_idx = 0 while queue: node_idx, bbox, faces = queue.pop(0) + + #logger.debug(f"Processing node {node_idx} with {len(faces)} faces.") self.node_bboxes[node_idx] = bbox # 判断 要不要继续分裂 - if not self._should_split_node(current_idx): - self.is_leaf_mask[node_idx] = True + if not self._should_split_node(current_idx, faces, total_nodes): continue - + + self.is_leaf_mask[node_idx] = 0 # 计算子节点边界框 min_coords = bbox[:3] max_coords = bbox[3:] @@ -92,36 +167,35 @@ class OctreeNode(nn.Module): # 生成8个子节点 child_bboxes = self._generate_child_bboxes(min_coords, mid_coords, max_coords) + intersect_mask = bbox_intersect(self.surf_bbox, faces, child_bboxes) + self.face_indices_mask[current_idx + 1:current_idx + 9, :] = intersect_mask # 为每个子节点分配面片 for i, child_bbox in enumerate(child_bboxes): - child_idx = current_idx + 1 - current_idx += 1 - - # 找到与子包围盒相交的面 - intersecting_faces = [] - for face_idx in faces: - face_bbox = self.surf_bbox[face_idx] - if bbox_intersect(child_bbox, face_bbox).item(): - intersecting_faces.append(face_idx) + child_idx = child_idx = current_idx + i + 1 + intersecting_faces = intersect_mask[i].nonzero().flatten() + #logger.debug(f"Node {child_idx} has {len(intersecting_faces)} intersecting faces.") # 更新节点关系 self.parent_indices[child_idx] = node_idx self.child_indices[node_idx, i] = child_idx # 将子节点加入队列 - if intersecting_faces: - queue.append((child_idx, child_bbox, torch.tensor(intersecting_faces, device=self.device))) + if len(intersecting_faces) > 0: + queue.append((child_idx, child_bbox, intersecting_faces.clone().detach())) + current_idx += 8 - def _should_split_node(self, current_depth: int) -> bool: + def _should_split_node(self, current_idx: int,face_indices,max_node:int) -> bool: """判断节点是否需要分裂""" # 检查是否达到最大深度 - if current_depth >= self.max_depth: + if current_idx + 8 >= max_node: return False # 检查是否为完全图 - is_clique = self.patch_graph.is_clique(self.face_indices) + #is_clique = self.patch_graph.is_clique(face_indices) + is_clique = face_indices.shape[0] < 2 if is_clique: + #logger.debug(f"Node {current_idx} is a clique. Stopping split.") return False return True @@ -153,9 +227,9 @@ class OctreeNode(nn.Module): @torch.jit.export def find_leaf(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]: """ - 查找包含给定点的叶子节点,并返回其信息 + 修改后的查找叶子节点方法,返回face indices :param query_points: 待查找的点,形状为 (3,) - :return: 包含叶子节点信息的元组 (bbox, param_key, is_leaf) + :return: (bbox, param_key, face_indices, is_leaf) """ # 确保输入是单个点 if query_points.dim() != 1 or query_points.shape[0] != 3: @@ -168,7 +242,28 @@ class OctreeNode(nn.Module): while iteration < max_iterations: # 获取当前节点的叶子状态 if self.is_leaf_mask[current_idx].item(): - return self.node_bboxes[current_idx], self.param_key, True + #logger.debug(f"Reached leaf node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.") + if self.face_indices_mask[current_idx].sum() == 0: + parent_idx = self.parent_indices[current_idx] + #logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.") + if parent_idx == -1: + # 根节点没有父节点,返回根节点的信息 + #logger.warning(f"Reached root node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.") + return ( + self.node_bboxes[current_idx], + None, # 新增返回face indices + False + ) + return ( + self.node_bboxes[parent_idx], + self.face_indices_mask[parent_idx], # 新增返回face indices + False + ) + return ( + self.node_bboxes[current_idx], + self.face_indices_mask[current_idx], # 新增返回face indices + True + ) # 计算子节点索引 child_idx = self._get_child_indices(query_points.unsqueeze(0), @@ -185,7 +280,7 @@ class OctreeNode(nn.Module): iteration += 1 # 如果达到最大迭代次数,返回当前节点的信息 - return self.node_bboxes[current_idx], self.param_key, bool(self.is_leaf_mask[current_idx].item()) + return self.node_bboxes[current_idx], None,bool(self.is_leaf_mask[current_idx].item()) @torch.jit.export def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor: @@ -195,43 +290,70 @@ class OctreeNode(nn.Module): def forward(self, query_points): with torch.no_grad(): - param_indices, bboxes = [], [] + bboxes, face_indices_mask, csg_trees = [], [], [] for point in query_points: - bbox, idx, _ = self.find_leaf(point) - param_indices.append(idx) + bbox, faces_mask, _ = self.find_leaf(point) bboxes.append(bbox) - param_indices = torch.stack(param_indices) - bboxes = torch.stack(bboxes) - # 添加检查代码 - return param_indices, bboxes + face_indices_mask.append(faces_mask) + # 获取当前节点的CSG树结构 + csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None + csg_trees.append(csg_tree) # 保持原始列表结构 + return ( + torch.stack(bboxes), + torch.stack(face_indices_mask), + csg_trees # 直接返回列表,不转换为张量 + ) - def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None: + def print_tree(self, max_print_depth: int = None) -> None: """ - 递归打印八叉树结构 + 使用深度优先遍历 (DFS) 打印树结构,父子关系通过缩进体现。 参数: - depth: 当前深度 (内部使用) - max_print_depth: 最大打印深度 (None表示打印全部) + max_print_depth (int): 最大打印深度 (None 表示打印全部) """ - if max_print_depth is not None and depth > max_print_depth: - return + def dfs(node_idx: int, depth: int): + """ + 深度优先遍历辅助函数。 - # 打印当前节点信息 - indent = " " * depth - node_type = "Leaf" if self._is_leaf else "Internal" - print(f"{indent}L{depth} [{node_type}] BBox: {self.bbox.cpu().numpy().tolist()}") - - # 打印面片信息(如果有) - if self.face_indices is not None: - print(f"{indent} Face indices: {self.face_indices.cpu().numpy().tolist()}") - print(f"{indent} Child indices: {self.child_indices.cpu().numpy().tolist()}") - - # 打印子节点信息 - if self.child_indices is not None: - for i in range(8): - child_idx = self.child_indices[0, i].item() - if child_idx != -1: - print(f"{indent} Child {i}: Node {child_idx}") + 参数: + node_idx (int): 当前节点索引 + depth (int): 当前节点的深度 + """ + # 如果超过最大打印深度,跳过当前节点及其子节点 + if max_print_depth is not None and depth > max_print_depth: + return + + indent = " " * depth # 根据深度生成缩进 + is_leaf = self.is_leaf_mask[node_idx].item() # 判断是否为叶子节点 + bbox = self.node_bboxes[node_idx].cpu().numpy().tolist() # 获取边界框信息 + + # 打印当前节点的基本信息 + node_type = "Leaf" if is_leaf else "Internal" + log_lines.append(f"{indent}L{depth} [{node_type}] NODE_ID-{node_idx}, BBox: {bbox}") + if self.face_indices_mask is not None: + face_indices = self.face_indices_mask[node_idx].nonzero().cpu().numpy().flatten().tolist() + log_lines.append(f"{indent} Face Indices: {face_indices}") + # 如果是叶子节点,打印额外信息 + if is_leaf: + + child_indices = self.child_indices[node_idx].cpu().numpy().tolist() + log_lines.append(f"{indent} Child Indices: {child_indices}") + + # 如果不是叶子节点,递归处理子节点 + if not is_leaf: + for i in range(8): # 遍历所有子节点 + child_idx = self.child_indices[node_idx, i].item() + if child_idx != -1: # 忽略无效的子节点索引 + dfs(child_idx, depth + 1) + + # 初始化日志行列表 + log_lines = [] + + # 从根节点开始深度优先遍历 + dfs(0, 0) + + # 统一输出所有日志 + logger.debug("\n".join(log_lines)) def __getstate__(self): """支持pickle序列化""" @@ -245,7 +367,6 @@ class OctreeNode(nn.Module): 'surf_bbox': self.surf_bbox, 'patch_graph': self.patch_graph, 'max_depth': self.max_depth, - 'param_key': self.param_key, '_is_leaf': self._is_leaf } return state @@ -261,7 +382,6 @@ class OctreeNode(nn.Module): self.surf_bbox = state['surf_bbox'] self.patch_graph = state['patch_graph'] self.max_depth = state['max_depth'] - self.param_key = state['param_key'] self._is_leaf = state['_is_leaf'] def to(self, device=None, dtype=None, non_blocking=False): diff --git a/brep2sdf/networks/patch_graph.py b/brep2sdf/networks/patch_graph.py index 914193e..0239942 100644 --- a/brep2sdf/networks/patch_graph.py +++ b/brep2sdf/networks/patch_graph.py @@ -2,6 +2,7 @@ from typing import Tuple import torch import torch.nn as nn import numpy as np +from brep2sdf.utils.logger import logger class PatchGraph(nn.Module): def __init__(self, num_patches: int, device: torch.device = None): @@ -56,13 +57,45 @@ class PatchGraph(nn.Module): return subgraph_edges, subgraph_types - + def get_csg_tree(self, node_faces_mask: torch.Tensor): + """生成CSG组合树结构 + 参数: + node_faces: 要处理的面片索引集合,形状为 (N,) + 返回: + 嵌套列表结构,表示CSG组合层次 + 示例: + [[0, [1,2]], 3] 表示0与(1和2的组合)进行凹组合,然后与3进行凸组合 + """ + print("node_faces_mask:", node_faces_mask) + if self.edge_index is None: + return [] + node_faces = node_faces_mask.nonzero() + node_faces = node_faces.flatten().to('cpu').numpy() + logger.debug(f"node_faces: {node_faces}") + node_set = set(node_faces) # 创建输入面片的集合用于快速查找 + visited = set() + csg_tree = [] + + # 优先处理凹边连接 + concave_edges = self.edge_index[:, self.edge_type == 0].cpu().numpy().T + for u, v in concave_edges: + u, v = int(u), int(v) + if u in node_set and v in node_set and u not in visited and v not in visited: + csg_tree.append([u, v]) + visited.update({u, v}) + + # 处理剩余面片(只包含输入的面片) + remaining = [int(f) for f in node_faces if f not in visited] + csg_tree.extend(remaining) + + return csg_tree def is_clique(self, node_faces: torch.Tensor) -> bool: """检查给定面片集合是否构成完全图 参数: node_faces: 要检查的面片索引集合 + face: [0,1,2,3,4,] 返回: bool: 是否为完全图 diff --git a/brep2sdf/test.py b/brep2sdf/test.py index 498d4b8..4266e2c 100644 --- a/brep2sdf/test.py +++ b/brep2sdf/test.py @@ -1,4 +1,69 @@ import torch -model = torch.jit.load("/home/wch/brep2sdf/data/output_data/00000054.pt") -print(model) \ No newline at end of file +from typing import List, Tuple + +def bbox_intersect(surf_bboxes: torch.Tensor, indices: torch.Tensor, child_bboxes: torch.Tensor) -> torch.Tensor: + ''' + args: + surf_bboxes: [B, 6] - 表示多个包围盒的张量,每个包围盒由其最小和最大坐标定义。 + indices: 选择bbox, [N], N <= B - 用于选择特定包围盒的索引张量。 + child_bboxes: [8, 6] - 一个包围盒被分成八个子包围盒后的结果。 + return: + intersect_mask: [8, N] - 布尔掩码,表示每个子包围盒与选择的包围盒是否相交。 + ''' + # 提取选中的边界框 + selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6] + min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3] + min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3] + + # 确保广播机制正常工作 + intersect_mask = torch.all( + (max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3] + (max2.unsqueeze(1) >= min1.unsqueeze(0)), # 形状为 [8, N, 3] + dim=-1 + ) # 最终形状为 [8, N] + + return intersect_mask + +# 测试程序 +if __name__ == "__main__": + # 构造输入数据 + surf_bboxes = torch.tensor([ + [0, 0, 0, 1, 1, 1], # 立方体 1 + [0.5, 0.5, 0.5, 1.5, 1.5, 1.5], # 立方体 2 + [2, 2, 2, 3, 3, 3] # 立方体 3 + ]) # [B=3, 6] + + indices = torch.tensor([0, 1]) # 选择前两个立方体 + + # 假设父边界框为 [0, 0, 0, 2, 2, 2],生成其八个子边界框 + parent_bbox = torch.tensor([0, 0, 0, 2, 2, 2]) + center = (parent_bbox[:3] + parent_bbox[3:]) / 2 + child_bboxes = torch.tensor([ + [parent_bbox[0], parent_bbox[1], parent_bbox[2], center[0], center[1], center[2]], # 左下前 + [center[0], parent_bbox[1], parent_bbox[2], parent_bbox[3], center[1], center[2]], # 右下前 + [parent_bbox[0], center[1], parent_bbox[2], center[0], parent_bbox[4], center[2]], # 左上前 + [center[0], center[1], parent_bbox[2], parent_bbox[3], parent_bbox[4], center[2]], # 右上前 + [parent_bbox[0], parent_bbox[1], center[2], center[0], center[1], parent_bbox[5]], # 左下后 + [center[0], parent_bbox[1], center[2], parent_bbox[3], center[1], parent_bbox[5]], # 右下后 + [parent_bbox[0], center[1], center[2], center[0], parent_bbox[4], parent_bbox[5]], # 左上后 + [center[0], center[1], center[2], parent_bbox[3], parent_bbox[4], parent_bbox[5]] # 右上后 + ]) # [8, 6] + + # 调用函数 + intersect_mask = bbox_intersect(surf_bboxes, indices, child_bboxes) + + # 输出结果 + print("Intersect Mask:") + print(intersect_mask) + + # 将布尔掩码转换为索引列表 + child_indices = [] + for i in range(8): # 遍历每个子节点 + intersecting_faces = indices[intersect_mask[i]] # 获取当前子节点的相交面片索引 + child_indices.append(intersecting_faces) + + # 打印每个子节点对应的相交索引 + print("\nChild Indices:") + for i, indices in enumerate(child_indices): + print(f"Child {i}: {indices}") \ No newline at end of file diff --git a/brep2sdf/train.py b/brep2sdf/train.py index 64fe51e..a665afb 100644 --- a/brep2sdf/train.py +++ b/brep2sdf/train.py @@ -116,11 +116,12 @@ class Trainer: device=self.device ) - self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=8) + self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=6) logger.gpu_memory_stats("数初始化后") self.model = Net( octree=self.root, + volume_bboxs=surf_bbox, feature_dim=64 ).to(self.device) logger.gpu_memory_stats("模型初始化后") @@ -138,7 +139,7 @@ class Trainer: logger.info(f"初始化完成,正在处理模型 {self.model_name}") - def build_tree(self,surf_bbox, graph, max_depth=6): + def build_tree(self,surf_bbox, graph, max_depth=9): num_faces = surf_bbox.shape[0] bbox = self._calculate_global_bbox(surf_bbox) self.root = OctreeNode( @@ -147,13 +148,13 @@ class Trainer: patch_graph=graph, max_depth=max_depth, surf_bbox=surf_bbox, - + surf_ncs=self.data['surf_ncs'] ) #print(surf_bbox) logger.info("starting octree conduction") self.root.build_static_tree() logger.info("complete octree conduction") - #self.root.print_tree(0) + self.root.print_tree() def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor: """ @@ -190,10 +191,6 @@ class Trainer: def train_epoch(self, epoch: int) -> float: - self.model.train() - total_loss = 0.0 - step = 0 # 如果你的训练是分批次的,这里应该用批次索引 - # --- 1. 检查输入数据 --- # 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列) # 并且 SDF 值总是在最后一列 @@ -201,114 +198,124 @@ class Trainer: logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.") return float('inf') - points = self.sdf_data[:, 0:3].clone().detach() # 取前3列作为点 - gt_sdf = self.sdf_data[:, -1].clone().detach() # 取最后一列作为SDF真值 - normals = None - if args.use_normal: - if self.sdf_data.shape[1] < 7: # 检查是否有足够的列给法线 - logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.") - return float('inf') - normals = self.sdf_data[:, 3:6].clone().detach() # 取中间3列作为法线 - - # 执行检查 - if self.debug_mode: - if check_tensor(points, "Input Points", epoch, step): return float('inf') - if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf') + self.model.train() + total_loss = 0.0 + step = 0 # 如果你的训练是分批次的,这里应该用批次索引 + batch_size = 10240 # 设置合适的batch大小 + + # 将数据分成多个batch + num_points = self.sdf_data.shape[0] + num_batches = (num_points + batch_size - 1) // batch_size + + for batch_idx in range(num_batches): + start_idx = batch_idx * batch_size + end_idx = min((batch_idx + 1) * batch_size, num_points) + points = self.sdf_data[start_idx:end_idx, 0:3].clone().detach() # 取前3列作为点 + gt_sdf = self.sdf_data[start_idx:end_idx, -1].clone().detach() # 取最后一列作为SDF真值 + normals = None if args.use_normal: - # 只有在请求法线时才检查 normals - if check_tensor(normals, "Input Normals", epoch, step): return float('inf') + if self.sdf_data.shape[1] < 7: # 检查是否有足够的列给法线 + logger.error(f"Epoch {epoch}: --use-normal is specified, but sdf_data has only {self.sdf_data.shape[1]} columns.") + return float('inf') + normals = self.sdf_data[start_idx:end_idx, 3:6].clone().detach() # 取中间3列作为法线 + + # 执行检查 + if self.debug_mode: + if check_tensor(points, "Input Points", epoch, step): return float('inf') + if check_tensor(gt_sdf, "Input GT SDF", epoch, step): return float('inf') + if args.use_normal: + # 只有在请求法线时才检查 normals + if check_tensor(normals, "Input Normals", epoch, step): return float('inf') - # --- 准备模型输入,启用梯度 --- - points.requires_grad_(True) # 在检查之后启用梯度 - - # --- 前向传播 --- - self.optimizer.zero_grad() - pred_sdf = self.model(points) + # --- 准备模型输入,启用梯度 --- + points.requires_grad_(True) # 在检查之后启用梯度 - if self.debug_mode: - # --- 检查前向传播的输出 --- - logger.gpu_memory_stats("前向传播后") - # --- 2. 检查模型输出 --- - #if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf') + # --- 前向传播 --- + self.optimizer.zero_grad() + pred_sdf = self.model(points) - # --- 计算损失 --- - loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败 - loss_details = {} - try: - # --- 3. 检查损失计算前的输入 --- - # (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf) - #if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss") - #if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss") - if args.use_normal: - # 检查法线和带梯度的点 - #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") - #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") - logger.gpu_memory_stats("计算损失前") - loss, loss_details = self.loss_manager.compute_loss( - points, - normals, # 传递检查过的 normals - gt_sdf, - pred_sdf - ) - else: - loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) - - # --- 4. 检查损失计算结果 --- if self.debug_mode: - if check_tensor(loss, "Calculated Loss", epoch, step): - logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.") - if loss_details: logger.error(f"Loss Details: {loss_details}") - return float('inf') # 如果损失无效,停止这个epoch - - except Exception as loss_e: - logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True) - return float('inf') # 如果计算出错,停止这个epoch - logger.gpu_memory_stats("损失计算后") - - # --- 反向传播和优化 --- - try: - loss.backward() - - # --- 5. (可选) 检查梯度 --- - # for name, param in self.model.named_parameters(): - # if param.grad is not None: - # if check_tensor(param.grad, f"Gradient/{name}", epoch, step): - # logger.warning(f"Epoch {epoch} Step {step}: Bad gradient for {name}. Consider clipping or zeroing.") - # # 例如:param.grad.data.clamp_(-1, 1) # 值裁剪 - # # 或在 optimizer.step() 前进行范数裁剪: - # # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - - # --- (推荐) 添加梯度裁剪 --- - # 防止梯度爆炸,这可能是导致 inf/nan 的原因之一 - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # 范数裁剪 - - self.optimizer.step() - except Exception as backward_e: - logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True) - # 如果你想看是哪个操作导致的,可以启用 anomaly detection - # torch.autograd.set_detect_anomaly(True) # 放在训练开始前 - return float('inf') # 如果反向传播或优化出错,停止这个epoch - - - # --- 记录和累加损失 --- - current_loss = loss.item() - if not np.isfinite(current_loss): # 再次确认损失是有效的数值 - logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).") - return float('inf') + # --- 检查前向传播的输出 --- + logger.gpu_memory_stats("前向传播后") + # --- 2. 检查模型输出 --- + #if check_tensor(pred_sdf, "Predicted SDF (Model Output)", epoch, step): return float('inf') + + # --- 计算损失 --- + loss = torch.tensor(float('nan'), device=self.device) # 初始化为 NaN 以防计算失败 + loss_details = {} + try: + # --- 3. 检查损失计算前的输入 --- + # (points 已经启用梯度,不需要再次检查inf/nan,但检查pred_sdf和gt_sdf) + #if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss") + #if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss") + if args.use_normal: + # 检查法线和带梯度的点 + #if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss") + #if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss") + logger.gpu_memory_stats("计算损失前") + loss, loss_details = self.loss_manager.compute_loss( + points, + normals, # 传递检查过的 normals + gt_sdf, + pred_sdf + ) + else: + loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf) + + # --- 4. 检查损失计算结果 --- + if self.debug_mode: + if check_tensor(loss, "Calculated Loss", epoch, step): + logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.") + if loss_details: logger.error(f"Loss Details: {loss_details}") + return float('inf') # 如果损失无效,停止这个epoch + + except Exception as loss_e: + logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True) + return float('inf') # 如果计算出错,停止这个epoch + logger.gpu_memory_stats("损失计算后") + + # --- 反向传播和优化 --- + try: + loss.backward() + + # --- 5. (可选) 检查梯度 --- + # for name, param in self.model.named_parameters(): + # if param.grad is not None: + # if check_tensor(param.grad, f"Gradient/{name}", epoch, step): + # logger.warning(f"Epoch {epoch} Step {step}: Bad gradient for {name}. Consider clipping or zeroing.") + # # 例如:param.grad.data.clamp_(-1, 1) # 值裁剪 + # # 或在 optimizer.step() 前进行范数裁剪: + # # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + # --- (推荐) 添加梯度裁剪 --- + # 防止梯度爆炸,这可能是导致 inf/nan 的原因之一 + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # 范数裁剪 + + self.optimizer.step() + except Exception as backward_e: + logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True) + # 如果你想看是哪个操作导致的,可以启用 anomaly detection + # torch.autograd.set_detect_anomaly(True) # 放在训练开始前 + return float('inf') # 如果反向传播或优化出错,停止这个epoch + + + # --- 记录和累加损失 --- + current_loss = loss.item() + if not np.isfinite(current_loss): # 再次确认损失是有效的数值 + logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).") + return float('inf') - total_loss += current_loss + total_loss += current_loss + del loss + torch.cuda.empty_cache() + # 记录训练进度 (只记录有效的损失) logger.info(f'Train Epoch: {epoch:4d}]\t' f'Loss: {current_loss:.6f}') if loss_details: logger.info(f"Loss Details: {loss_details}") - # (如果你的训练分批次,这里应该继续循环下一批次) - # step += 1 - del loss - torch.cuda.empty_cache() # 清空缓存 - return total_loss # 对于单批次训练,直接返回当前损失 def validate(self, epoch: int) -> float: diff --git a/brep2sdf/utils/logger.py b/brep2sdf/utils/logger.py index 33bb409..ee58090 100644 --- a/brep2sdf/utils/logger.py +++ b/brep2sdf/utils/logger.py @@ -227,7 +227,7 @@ class BRepLogger: stats.append(f" 峰值: {max_allocated:.1f} MB") # 一次性输出所有统计信息 - self.info("\n".join(stats)) + self.debug("\n".join(stats)) # 获取每个张量的内存使用情况 if include_trace: