Browse Source

可以训练,但是和torch jit还有一些兼容问题

final
mckay 1 month ago
parent
commit
f98c9ac227
  1. 3
      brep2sdf/config/default_config.py
  2. 1
      brep2sdf/data/data.py
  3. 1
      brep2sdf/data/utils.py
  4. 33
      brep2sdf/networks/decoder.py
  5. 22
      brep2sdf/networks/encoder.py
  6. 52
      brep2sdf/networks/network.py
  7. 89
      brep2sdf/networks/octree.py
  8. 34
      brep2sdf/networks/patch_graph.py
  9. 137
      brep2sdf/train.py

3
brep2sdf/config/default_config.py

@ -9,6 +9,7 @@ class ModelConfig:
embed_dim: int = 768 # 3 的 倍数
latent_dim: int = 32
octree_max_depth = 6
# 点云采样配置
num_surf_points: int = 64 # 每个面采样点数
num_edge_points: int = 8 # 每条边采样点数
@ -48,7 +49,7 @@ class TrainConfig:
# 基本训练参数
batch_size: int = 8
num_workers: int = 4
num_epochs: int = 1000
num_epochs: int = 1
learning_rate: float = 0.001
min_lr: float = 1e-5
weight_decay: float = 0.01

1
brep2sdf/data/data.py

@ -141,7 +141,6 @@ def prepare_sdf_data(surf_data, normals=None, max_points=100000, device='cuda'):
return torch.tensor(sdf_array, dtype=torch.float32, device=device)
def print_data_distribution(data: torch.Tensor) -> None:
"""打印数据分布统计信息

1
brep2sdf/data/utils.py

@ -56,6 +56,7 @@ def process_surf_ncs_with_dynamic_padding(surf_ncs: np.ndarray) -> torch.Tensor:
return padded_tensor
def normalize(surfs, edges, corners):
"""
将CAD模型归一化到单位立方体空间

33
brep2sdf/networks/decoder.py

@ -88,7 +88,34 @@ class Decoder(nn.Module):
return f_i
@torch.jit.export
def forward_training_volumes(self, feature_matrix: torch.Tensor) -> torch.Tensor:
    """Decode per-sample features into SDF values with the ``sdf_*`` layers.

    :param feature_matrix: feature matrix, documented by the author as
        shape (S, D) with S samples and D feature channels.
    :return: output of the last ``sdf_*`` layer with its trailing axis
        squeezed away; documented as shape (S).
    """
    last_layer = self.sdf_layers - 2
    # The input is consumed as-is; it is expected to already be (S, D).
    hidden = feature_matrix
    for layer_idx in range(last_layer + 1):
        linear = getattr(self, "sdf_" + str(layer_idx))
        if layer_idx in self.skip_in:
            # Skip layers concatenate the current activations with
            # themselves and rescale by 1/sqrt(2).
            # NOTE(review): similar skip-MLPs concatenate the network
            # input here instead - confirm this is intended.
            doubled = torch.cat([hidden, hidden], -1)
            hidden = doubled / np.sqrt(2)
        hidden = linear(hidden)
        # Every layer except the final one is followed by the activation.
        if layer_idx != last_layer:
            hidden = self.activation(hidden)
    # Drop the trailing axis so callers receive one value per sample.
    return hidden.squeeze(-1)
"""
# 一个基础情形: 输入 fi 形状[P] 和 csg tree,凹凸组合输出h
#注意考虑如何批量处理 (B, P) 和 [csg tree]
class CSGCombiner:
@ -96,7 +123,8 @@ class CSGCombiner:
self.flag_convex = flag_convex
self.rho = rho
def forward(self, f_i: torch.Tensor, csg_tree) -> torch.Tensor:
def forward(self, f_i: torch.Tensor, csg_tree
) -> torch.Tensor:
'''
:param f_i: 形状为 (B, P) 的各patch SDF值
:param csg_tree: CSG树结构
@ -216,4 +244,5 @@ def test_csg_combiner():
print(f"rho={rho}:", h_soft)
if __name__ == "__main__":
test_csg_combiner()
test_csg_combiner()
"""

22
brep2sdf/networks/encoder.py

@ -75,18 +75,34 @@ class Encoder(nn.Module):
current_indices = volume_indices[:, k]
# 遍历所有存在的volume
for vol_id in range(len(self.feature_volumes)):
for vol_id, volume in enumerate(self.feature_volumes):
# 创建掩码 (B,)
mask = (current_indices == vol_id)
if mask.any():
# 获取对应volume的特征 (M, D)
features = self.feature_volumes[vol_id](query_points[mask])
features = volume.forward(query_points[mask])
all_features[mask, k] = features
return all_features
@torch.jit.export
def forward_training_volumes(self, surf_points: torch.Tensor, patch_id: int) -> torch.Tensor:
    """Query one patch's feature volume at that patch's surface samples.

    Args:
        surf_points: surface sample points, documented as (S, 3) where S
            is the number of points sampled for one feature volume.
        patch_id: position in ``self.feature_volumes`` of the volume to
            query.

    Returns:
        Whatever the selected volume's ``forward`` returns (documented as
        (S, D)). If ``patch_id`` matches no volume, a zero tensor of shape
        (S, ``self.feature_dim``) on the device of ``surf_points``.
    """
    # Walk the container with enumerate instead of indexing it with a
    # runtime integer; the original author avoids direct indexing here.
    for idx, volume in enumerate(self.feature_volumes):
        if idx == patch_id:
            return volume.forward(surf_points)
    # Fallback when no volume carries this id. The stray ``return features``
    # that used to follow was unreachable and named an unbound variable.
    return torch.zeros(surf_points.shape[0], self.feature_dim, device=surf_points.device)
def _optimized_trilinear(self, points, bboxes, features):
def _optimized_trilinear(self, points, bboxes, features)-> torch.Tensor:
"""优化后的向量化三线性插值"""
# 添加显式类型转换确保计算稳定性
min_coords = bboxes[..., :3].to(torch.float32)

52
brep2sdf/networks/network.py

@ -49,7 +49,7 @@ import torch
import torch.nn as nn
from torch.autograd import grad
from .encoder import Encoder
from .decoder import Decoder, CSGCombiner
from .decoder import Decoder
from brep2sdf.utils.logger import logger
class Net(nn.Module):
@ -82,8 +82,9 @@ class Net(nn.Module):
beta=100
)
self.csg_combiner = CSGCombiner(flag_convex=True)
#self.csg_combiner = CSGCombiner(flag_convex=True)
@torch.jit.export
def forward(self, query_points):
"""
前向传播
@ -94,18 +95,55 @@ class Net(nn.Module):
output: 解码后的输出结果
"""
# 批量查询所有点的索引和bbox
_,face_indices,csg_trees = self.octree_module.forward(query_points)
_,face_indices_mask,operator = self.octree_module.forward(query_points)
# 编码
feature_vectors = self.encoder.forward(query_points,face_indices)
#print("feature_vector:", feature_vectors.requires_grad)
feature_vectors = self.encoder.forward(query_points,face_indices_mask)
print("feature_vector:", feature_vectors.shape)
# 解码
logger.gpu_memory_stats("encoder farward后")
f_i = self.decoder(feature_vectors)
f_i = self.decoder(feature_vectors) # (B, P)
logger.gpu_memory_stats("decoder farward后")
output = self.csg_combiner.forward(f_i, csg_trees)
output = f_i[:, 0]
# 提取有效值并填充到固定大小 (B, max_patches)
padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches)
for i in range(f_i.shape[0]):
sample_valid_values = f_i[i][face_indices_mask[i]] # (N,), N <= P
num_valid = min(len(sample_valid_values), 2)
padded_f_i[i, :num_valid] = sample_valid_values[:num_valid]
# 找到需要组合的行
mask_concave = (operator == 0)
mask_convex = (operator == 1)
# 对 operator == 0 的样本取最大值
if mask_concave.any():
output[mask_concave] = torch.max(padded_f_i[mask_concave], dim=1).values
# 对 operator == 1 的样本取最小值
if mask_convex.any():
output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values
logger.gpu_memory_stats("combine后")
return output
@torch.jit.export
def forward_training_volumes(self, surf_points, patch_id: int):
    """Predict SDF values for the surface samples of a single patch.

    Only the encoder feature volume selected by ``patch_id`` and the
    decoder are evaluated; no octree lookup takes place in this method.

    Args:
        surf_points: surface samples of one patch; the encoder documents
            this input as shape (S, 3).
        patch_id: index of the patch whose feature volume is queried.

    Returns:
        The decoder output with singleton axes removed. It always has at
        least one dimension, so a single sample yields shape (1,) rather
        than a 0-d scalar.
    """
    feature_mat = self.encoder.forward_training_volumes(surf_points, patch_id)
    f_i = self.decoder.forward_training_volumes(feature_mat)
    f_i = f_i.squeeze()
    # A bare squeeze() turns a one-sample result into a 0-d tensor, which no
    # longer lines up with a per-sample target of shape (1,); restore the axis.
    if f_i.dim() == 0:
        f_i = f_i.unsqueeze(0)
    return f_i
def gradient(inputs, outputs):
d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)

89
brep2sdf/networks/octree.py

@ -1,4 +1,4 @@
from typing import Tuple
from typing import Tuple,List, Optional
import torch
import torch.nn as nn
@ -54,7 +54,7 @@ def bbox_intersect(
surf_bboxes: torch.Tensor,
indices: torch.Tensor,
child_bboxes: torch.Tensor,
surf_points: torch.Tensor = None
surf_points: Optional[torch.Tensor]=None
) -> torch.Tensor:
'''
args:
@ -69,15 +69,15 @@ def bbox_intersect(
# 初始化全为 False 的结果掩码 [8, B]
B = surf_bboxes.size(0)
result_mask = torch.zeros((8, B), dtype=torch.bool).to(surf_bboxes.device)
logger.debug(result_mask.shape)
logger.debug(indices.shape)
#logger.debug(result_mask.shape)
#logger.debug(indices.shape)
# 提取选中的边界框
selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6]
min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3]
min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3]
logger.debug(selected_bboxes.shape)
#logger.debug(selected_bboxes.shape)
# 计算子包围盒与选中包围盒的交集
intersect_mask = torch.all(
(max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3]
@ -109,11 +109,11 @@ def bbox_intersect(
# 合并交集条件和点云条件
result_mask[:, indices] = result_mask[:, indices] & points_in_boxes_mask
logger.debug(result_mask.shape)
#logger.debug(result_mask.shape)
return result_mask
class OctreeNode(nn.Module):
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: torch.Tensor = None, patch_graph: PatchGraph = None,surf_ncs:np.ndarray = None,device=None):
def __init__(self, bbox: torch.Tensor, face_indices: np.ndarray, max_depth: int = 5, surf_bbox: Optional[torch.Tensor]=None, patch_graph: Optional[PatchGraph] = None,surf_ncs:Optional[np.ndarray] = None,device:Optional[torch.device]=None):
super().__init__()
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 改为普通张量属性
@ -185,7 +185,7 @@ class OctreeNode(nn.Module):
queue.append((child_idx, child_bbox, intersecting_faces.clone().detach()))
current_idx += 8
def _should_split_node(self, current_idx: int,face_indices,max_node:int) -> bool:
def _should_split_node(self, current_idx: int,face_indices:torch.Tensor,max_node:int) -> bool:
"""判断节点是否需要分裂"""
# 检查是否达到最大深度
if current_idx + 8 >= max_node:
@ -229,7 +229,7 @@ class OctreeNode(nn.Module):
"""
修改后的查找叶子节点方法返回face indices
:param query_points: 待查找的点形状为 (3,)
:return: (bbox, param_key, face_indices, is_leaf)
:return: (bbox, face_indices, is_leaf)
"""
# 确保输入是单个点
if query_points.dim() != 1 or query_points.shape[0] != 3:
@ -248,10 +248,9 @@ class OctreeNode(nn.Module):
#logger.debug(f"Use parent node {parent_idx}, with {self.face_indices_mask[parent_idx].sum()} faces.")
if parent_idx == -1:
# 根节点没有父节点,返回根节点的信息
#logger.warning(f"Reached root node {current_idx}, with {self.face_indices_mask[current_idx].sum()} faces.")
return (
self.node_bboxes[current_idx],
None, # 新增返回face indices
torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device), # 新增返回face indices
False
)
return (
@ -280,7 +279,7 @@ class OctreeNode(nn.Module):
iteration += 1
# 如果达到最大迭代次数,返回当前节点的信息
return self.node_bboxes[current_idx], None,bool(self.is_leaf_mask[current_idx].item())
return self.node_bboxes[current_idx], torch.tensor([], dtype=torch.float32, device=self.node_bboxes.device),bool(self.is_leaf_mask[current_idx].item())
@torch.jit.export
def _get_child_indices(self, points: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
@ -288,23 +287,26 @@ class OctreeNode(nn.Module):
mid_points = (bboxes[:, :3] + bboxes[:, 3:]) / 2
return ((points >= mid_points) << torch.arange(3, device=points.device)).sum(dim=1)
def forward(self, query_points):
def forward(self, query_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
with torch.no_grad():
bboxes, face_indices_mask, csg_trees = [], [], []
bboxes: List[torch.Tensor] = []
face_indices_mask: List[torch.Tensor] = []
operator: List[int] = []
for point in query_points:
bbox, faces_mask, _ = self.find_leaf(point)
bbox, faces_mask, _ = self.find_leaf(point)
bboxes.append(bbox)
face_indices_mask.append(faces_mask)
# 获取当前节点的CSG树结构
csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None
csg_trees.append(csg_tree) # 保持原始列表结构
#csg_tree = self.patch_graph.get_csg_tree(faces_mask) if self.patch_graph else None
#csg_trees.append(csg_tree) # 保持原始列表结构
operator.append(self.patch_graph.get_operator(faces_mask.nonzero()) if self.patch_graph is not None else -1)
return (
torch.stack(bboxes),
torch.stack(face_indices_mask),
csg_trees # 直接返回列表,不转换为张量
torch.tensor(operator, dtype=torch.int) # 直接返回列表,不转换为张量
)
def print_tree(self, max_print_depth: int = None) -> None:
def print_tree(self, max_print_depth: Optional[int] = None) -> None:
"""
使用深度优先遍历 (DFS) 打印树结构父子关系通过缩进体现
@ -353,36 +355,45 @@ class OctreeNode(nn.Module):
dfs(0, 0)
# 统一输出所有日志
logger.debug("\n".join(log_lines))
#logger.debug("\n".join(log_lines))
def __getstate__(self):
"""支持pickle序列化"""
state = {
'bbox': self.bbox,
'node_bboxes': self.node_bboxes,
'parent_indices': self.parent_indices,
'child_indices': self.child_indices,
'is_leaf_mask': self.is_leaf_mask,
'face_indices': self.face_indices,
'surf_bbox': self.surf_bbox,
'patch_graph': self.patch_graph,
'bbox': self.bbox.cpu(), # 转换为CPU张量
'node_bboxes': self.node_bboxes.cpu() if self.node_bboxes is not None else None,
'parent_indices': self.parent_indices.cpu() if self.parent_indices is not None else None,
'child_indices': self.child_indices.cpu() if self.child_indices is not None else None,
'is_leaf_mask': self.is_leaf_mask.cpu() if self.is_leaf_mask is not None else None,
'all_face_indices': self.all_face_indices.cpu(),
'face_indices_mask':self.face_indices_mask.cpu() if self.face_indices_mask is not None else None,
'surf_bbox': self.surf_bbox.cpu() if self.surf_bbox is not None else None,
'surf_ncs': self.surf_ncs.cpu() if self.surf_ncs is not None else None,
'patch_graph': self.patch_graph, # 假设PatchGraph支持序列化
'max_depth': self.max_depth,
'_is_leaf': self._is_leaf
'device': str(self.device) # 保存设备信息
}
return state
def __setstate__(self, state):
"""支持pickle反序列化"""
self.bbox = state['bbox']
self.node_bboxes = state['node_bboxes']
self.parent_indices = state['parent_indices']
self.child_indices = state['child_indices']
self.is_leaf_mask = state['is_leaf_mask']
self.face_indices = state['face_indices']
self.surf_bbox = state['surf_bbox']
self.patch_graph = state['patch_graph']
self.max_depth = state['max_depth']
self._is_leaf = state['_is_leaf']
# 手动调用 __init__ 方法
self.__init__(
bbox=state['bbox'],
face_indices=state['all_face_indices'].cpu().numpy(),
patch_graph=state['patch_graph'],
max_depth=state['max_depth'],
surf_bbox=state['surf_bbox'],
surf_ncs=state['surf_ncs'],
device=torch.device(state['device'])
)
# 可以在这里设置其他不需要在 __init__ 中处理的属性
self.node_bboxes = state['node_bboxes'].to(self.device) if state['node_bboxes'] is not None else None
self.parent_indices = state['parent_indices'].to(self.device) if state['parent_indices'] is not None else None
self.child_indices = state['child_indices'].to(self.device) if state['child_indices'] is not None else None
self.face_indices_mask = state['face_indices_mask'].to(self.device) if state['face_indices_mask'] is not None else None
self.is_leaf_mask = state['is_leaf_mask'].to(self.device) if state['is_leaf_mask'] is not None else None
def to(self, device=None, dtype=None, non_blocking=False):
# 调用父类方法迁移基础参数

34
brep2sdf/networks/patch_graph.py

@ -71,7 +71,7 @@ class PatchGraph(nn.Module):
return []
node_faces = node_faces_mask.nonzero()
node_faces = node_faces.flatten().to('cpu').numpy()
logger.debug(f"node_faces: {node_faces}")
#logger.debug(f"node_faces: {node_faces}")
node_set = set(node_faces) # 创建输入面片的集合用于快速查找
visited = set()
csg_tree = []
@ -89,6 +89,32 @@ class PatchGraph(nn.Module):
csg_tree.extend(remaining)
return csg_tree
def get_operator(self, node_faces: torch.Tensor) -> int:
    """Return the combination code for the faces stored in an octree node.

    Args:
        node_faces: indices of the faces in the node. Any shape is
            accepted (it is flattened); only the first two entries are
            considered.

    Returns:
        0 for a concave edge, 1 for a convex edge (also returned for a
        single face), and 3 when no operator can be determined: the face
        set is empty, the edge tensors are unset, or the first two faces
        share no edge in ``self.edge_index``.
    """
    no_operator = 3
    node_faces = node_faces.flatten().to(self.device)
    num_faces = node_faces.numel()
    if num_faces == 0:
        # Nothing to combine; previously this fell through to node_faces[0]
        # and raised IndexError.
        return no_operator
    if num_faces == 1:
        # Convex on purpose: the caller pads a second value f2 = inf and
        # takes h = min(f1, f2), which leaves h = f1.
        return 1
    if self.edge_index is None or self.edge_type is None:
        return no_operator

    # With more than two faces only the first pair is examined.
    first, second = node_faces[0], node_faces[1]
    src, dst = self.edge_index[0], self.edge_index[1]
    # Match the edge between the two faces in either direction.
    mask = ((src == first) & (dst == second)) | ((src == second) & (dst == first))
    if not mask.any():
        return no_operator
    # Several matching edges may exist; the first one decides.
    return int(self.edge_type[mask][0].item())
def is_clique(self, node_faces: torch.Tensor) -> bool:
"""检查给定面片集合是否构成完全图
@ -150,7 +176,7 @@ class PatchGraph(nn.Module):
@staticmethod
def from_preprocessed_data(
surf_wcs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组
surf_ncs: np.ndarray, # 形状为(N,)的对象数组,每个元素是形状为(M, 3)的float32数组
edgeFace_adj: np.ndarray, # 形状为(num_edges, num_faces)的int32数组
edge_types: np.ndarray, # 形状为(num_edges,)的int32数组
device: torch.device = None
@ -158,7 +184,7 @@ class PatchGraph(nn.Module):
"""从预处理的数据直接构建面片邻接图
参数:
surf_wcs: 世界坐标系下的曲面几何数据形状为(N,)的对象数组每个元素是形状为(M, 3)的float32数组
surf_ncs: 归一化坐标系下的曲面几何数据形状为(N,)的对象数组每个元素是形状为(M, 3)的float32数组
edgeFace_adj: -面邻接矩阵形状为(num_edges, num_faces)的int32数组1表示边与面相邻
edge_types: 边的类型数组形状为(num_edges,)的int32数组0表示凹边1表示凸边
@ -167,7 +193,7 @@ class PatchGraph(nn.Module):
- edge_index: 形状为(2, num_edges*2)的torch.long张量表示双向边的连接关系
- edge_type: 形状为(num_edges*2,)的torch.long张量表示每条边的类型
"""
num_faces = len(surf_wcs)
num_faces = len(surf_ncs)
graph = PatchGraph(num_faces,device)
# 构建边的索引和类型

137
brep2sdf/train.py

@ -104,7 +104,7 @@ class Trainer:
# 构建面片邻接图
graph = PatchGraph.from_preprocessed_data(
surf_wcs=self.data['surf_wcs'],
surf_ncs=self.data['surf_ncs'],
edgeFace_adj=self.data['edgeFace_adj'],
edge_types=self.data['edge_types'],
device='cuda' if args.octree_cuda else 'cpu'
@ -115,9 +115,15 @@ class Trainer:
dtype=torch.float32,
device=self.device
)
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=6)
logger.gpu_memory_stats("数初始化后")
max_depth = config.model.octree_max_depth
if not args.force_reprocess:
if not self._load_octree():
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth)
elif self.root.max_depth != max_depth:
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth)
else:
self.build_tree(surf_bbox=surf_bbox, graph=graph,max_depth=max_depth)
logger.gpu_memory_stats("树初始化后")
self.model = Net(
octree=self.root,
@ -140,6 +146,7 @@ class Trainer:
def build_tree(self,surf_bbox, graph, max_depth=9):
logger.info("开始构造八叉树...")
num_faces = surf_bbox.shape[0]
bbox = self._calculate_global_bbox(surf_bbox)
self.root = OctreeNode(
@ -155,6 +162,7 @@ class Trainer:
self.root.build_static_tree()
logger.info("complete octree conduction")
self.root.print_tree()
self._save_octree()
def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
"""
@ -188,8 +196,80 @@ class Trainer:
# # 返回合并后的边界框
# return torch.cat([global_min, global_max])
# return [-0.5,] # 这个是错误的
def train_epoch_stage1(self, epoch: int):
# Stage-1 epoch: iterates over the entries of self.data['surf_ncs'], runs
# self.model.forward_training_volumes on each entry with a zero SDF target,
# and returns the summed loss. Returns float('inf') as soon as any step
# produces an invalid loss or raises.
# NOTE(review): indentation is not visible in this view, so the exact nesting
# of the debug checks and the final logging statements should be confirmed
# against the real file.
total_loss = 0.0 # initialise the accumulated loss
for step, surf_points in enumerate(self.data['surf_ncs']): # step is also passed to the model as the patch id
points = torch.tensor(surf_points, device=self.device)
# Target SDF is zero for every sample (presumably because these points lie on the surface).
gt_sdf = torch.zeros(points.shape[0], device=self.device)
normals = None
if args.use_normal:
normals = torch.tensor(self.data["surf_pnt_normals"][step], device=self.device)
# --- Prepare the model input, enable gradients ---
points.requires_grad_(True) # enable gradients after the checks
# --- Forward pass ---
self.optimizer.zero_grad()
pred_sdf = self.model.forward_training_volumes(points, step)
if self.debug_mode:
# --- Inspect the forward-pass output ---
logger.gpu_memory_stats("前向传播后")
# --- Compute the loss ---
loss = torch.tensor(float('nan'), device=self.device) # start from NaN in case the computation fails
loss_details = {}
try:
if args.use_normal:
loss, loss_details = self.loss_manager.compute_loss(
points,
normals,
gt_sdf,
pred_sdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
if self.debug_mode:
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details:
logger.error(f"Loss Details: {loss_details}")
return float('inf') # abort this epoch if the loss is invalid
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # abort this epoch if the loss computation raised
# --- Backward pass and optimisation ---
try:
loss.backward()
# --- (Recommended) gradient clipping ---
# Guards against exploding gradients, one possible cause of inf/nan
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # clip by norm
self.optimizer.step()
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
return float('inf') # abort this epoch if backward or the optimizer step raised
# --- Record and accumulate the loss ---
current_loss = loss.item()
if not np.isfinite(current_loss): # double-check that the loss is a finite number
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
total_loss += current_loss
del loss
torch.cuda.empty_cache()
# Log training progress (valid losses only)
# NOTE(review): current_loss and loss_details are only bound inside the loop; if this
# logging sits after the loop, an empty self.data['surf_ncs'] would leave them unbound - confirm.
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
return total_loss
def train_epoch(self, epoch: int) -> float:
# --- 1. 检查输入数据 ---
# 注意:假设 self.sdf_data 包含 [x, y, z, nx, ny, nz, sdf] (7列) 或 [x, y, z, sdf] (4列)
@ -347,7 +427,8 @@ class Trainer:
for epoch in range(start_epoch, self.config.train.num_epochs + 1):
# 训练一个epoch
train_loss = self.train_epoch(epoch)
train_loss = self.train_epoch_stage1(epoch)
#train_loss = self.train_epoch(epoch)
# 验证
'''
@ -445,6 +526,50 @@ class Trainer:
except Exception as e:
logger.error(f"加载checkpoint失败: {str(e)}")
raise
# --- Octree persistence helpers (save / load) ---
def _save_octree(self):
    """Save the octree root node to ``<checkpoint_dir>/<model_name>/octree.pth``.

    The target directory is built from ``self.config.train.checkpoint_dir``
    and ``self.model_name`` and is created if it does not exist yet.
    Saving is best-effort: any failure is logged and not re-raised.
    """
    checkpoint_dir = os.path.join(
        self.config.train.checkpoint_dir,
        self.model_name
    )
    octree_path = os.path.join(checkpoint_dir, "octree.pth")

    try:
        # The directory may not exist the first time the tree is built;
        # without this torch.save fails and the tree is never cached.
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Persist the root node of the octree.
        torch.save(self.root, octree_path)
        logger.info(f"八叉树已保存到 {octree_path}")
    except Exception as e:
        logger.error(f"保存八叉树失败: {str(e)}")
def _load_octree(self) -> bool:
    """Try to restore the octree root from ``<checkpoint_dir>/<model_name>/octree.pth``.

    Returns:
        True when the file exists and was loaded into ``self.root``;
        False when the file is missing or loading raised (both cases are
        logged, nothing is re-raised).
    """
    octree_path = os.path.join(
        self.config.train.checkpoint_dir,
        self.model_name,
        "octree.pth",
    )

    loaded = False
    try:
        if not os.path.exists(octree_path):
            logger.warning(f"八叉树文件 {octree_path} 不存在,无法加载。")
        else:
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load octree files produced by this project.
            self.root = torch.load(octree_path, weights_only=False)
            logger.info(f"八叉树已从 {octree_path} 加载")
            loaded = True
    except Exception as e:
        logger.error(f"加载八叉树失败: {str(e)}")
    return loaded
def main():
# 这里需要初始化配置
config = get_default_config()

Loading…
Cancel
Save