Browse Source

Can load and save

final
mckay 2 months ago
parent
commit
1e4c360403
1. brep2sdf/networks/encoder.py (120 lines changed)
2. brep2sdf/networks/network.py (8 lines changed)
3. brep2sdf/networks/octree.py (139 lines changed)
4. brep2sdf/train.py (51 lines changed)

brep2sdf/networks/encoder.py (120 lines changed)

@@ -8,7 +8,7 @@ from .octree import OctreeNode
from brep2sdf.config.default_config import get_default_config
from brep2sdf.utils.logger import logger
import numpy as np
'''
class Encoder:
def __init__(self, surf_bbox: torch.Tensor, origin_bbox: torch.Tensor, max_depth: int, feature_dim:int = 64):
"""
@@ -59,11 +59,129 @@ class Encoder:
'''
class Encoder(nn.Module):
def __init__(self, octree: OctreeNode, feature_dim: int = 32):
"""
The decoupled encoder receives a pre-built octree.
Args:
octree: pre-built octree structure
feature_dim: feature dimension
"""
super().__init__()
self.octree = octree
self.feature_dim = feature_dim
# Register learnable parameters for all leaf nodes
self._init_parameters()
def _init_parameters(self):
"""为所有叶子节点初始化特征参数"""
# 使用字典保存所有参数,避免动态属性
self._leaf_parameters = nn.ParameterDict()
# Recursively traverse the tree structure
def _register_params(node, path=""):
#logger.debug(node.is_leaf())
if node.is_leaf():
param_name = f"leaf_{path}"
self._leaf_parameters[param_name] = nn.Parameter(
torch.randn(8, self.feature_dim) # features for the 8 cell corners
)
node.set_param_key(param_name) # store the parameter key on the node
#logger.debug(param_name)
#logger.debug(node.param_key)
else:
for i, child in enumerate(node.child_nodes):
_register_params(child, f"{path}_{i}")
_register_params(self.octree, "root")
def get_leaf_parameter(self, node):
"""获取叶子节点的特征参数"""
return self._leaf_parameters[node.param_key]
def forward(self, query_points: torch.Tensor) -> torch.Tensor:
"""
Forward pass over a batch of query points.
Args:
query_points: query point coordinates, shape (batch_size, 3)
Returns:
feature_vectors: feature vectors of the query points, shape (batch_size, feature_dim)
"""
batch_size = query_points.shape[0]
features = []
for i in range(batch_size):
# 1. Find the leaf node in the octree that contains this point
leaf_node = self.octree.find_leaf(query_points[i])
#logger.debug(leaf_node.param_key)
# 2. Get that node's feature parameters
bbox = leaf_node.bbox
node_features = self.get_leaf_parameter(leaf_node)
# 3. Compute the feature via trilinear interpolation
# (interpolation logic implemented below)
interpolated = self.trilinear_interpolation(
query_points[i], bbox, node_features)
features.append(interpolated)
return torch.stack(features, dim=0)
def trilinear_interpolation(self, query_point: torch.Tensor, bbox, features) -> torch.Tensor:
"""
Trilinear interpolation.
:param query_point: point to interpolate at, in (x, y, z) format
:return: interpolated result, shape (D,)
"""
# Make sure query_point and bbox are on the same device
#query_point = query_point.to(self.bbox.device)
# Get the min and max coordinates of the bounding box
min_coords = bbox[:3] # [min_x, min_y, min_z]
max_coords = bbox[3:] # [max_x, max_y, max_z]
# Compute normalized coordinates
normalized_coords = (query_point - min_coords) / (max_coords - min_coords + 1e-8) # avoid division by zero
x, y, z = normalized_coords.unbind(dim=-1)
# Use torch.stack to avoid converting to Python scalars
wx = torch.stack([1 - x, x], dim=-1) # keeps autograd intact
wy = torch.stack([1 - y, y], dim=-1)
wz = torch.stack([1 - z, z], dim=-1)
# Feature vectors of the 8 corners
c = features # shape (8, D)
# Perform trilinear interpolation
# interpolate along the x axis first
c00 = c[0] * wx[0] + c[1] * wx[1]
c01 = c[2] * wx[0] + c[3] * wx[1]
c10 = c[4] * wx[0] + c[5] * wx[1]
c11 = c[6] * wx[0] + c[7] * wx[1]
# then along the y axis
c0 = c00 * wy[0] + c10 * wy[1]
c1 = c01 * wy[0] + c11 * wy[1]
# finally along the z axis
result = c0 * wz[0] + c1 * wz[1]
return result
def to(self, device):
super().to(device)
def _move_node(node):
if isinstance(node.bbox, torch.Tensor):
node.bbox = node.bbox.to(device)
for child in node.child_nodes: # the attribute was renamed from `children`
_move_node(child)
_move_node(self.octree) # self.octree is itself the root node
return self
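
A quick corner sanity check for the trilinear interpolation above, as a hedged sketch: with the corner ordering implied by the weight interleaving (c[0] at the bbox minimum, c[7] at the maximum), a query on a corner should return that corner's feature row, and the cell center should return the mean of all eight. Encoder.trilinear_interpolation is called unbound here because it never touches self; the sizes and values are made-up test data.

import torch
from brep2sdf.networks.encoder import Encoder

features = torch.randn(8, 32)                        # stand-in leaf features, shape (8, D)
bbox = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])  # [min_x, min_y, min_z, max_x, max_y, max_z]
interp = Encoder.trilinear_interpolation             # self is unused inside the method

# Minimum corner -> c[0], maximum corner -> c[7] (up to the 1e-8 epsilon in the normalization).
assert torch.allclose(interp(None, bbox[:3], bbox, features), features[0], atol=1e-5)
assert torch.allclose(interp(None, bbox[3:], bbox, features), features[7], atol=1e-5)
# Center of the cell -> average of all eight corner features.
center = (bbox[:3] + bbox[3:]) / 2
assert torch.allclose(interp(None, center, bbox, features), features.mean(dim=0), atol=1e-5)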

brep2sdf/networks/network.py (8 lines changed)

@@ -53,9 +53,7 @@ from .decoder import Decoder
class Net(nn.Module):
def __init__(self,
surf_bbox,
origin_bbox,
max_depth=4,
octree,
feature_dim=64,
decoder_input_dim=64,
decoder_output_dim=1,
@@ -68,9 +66,7 @@ class Net(nn.Module):
# Initialize the Encoder
self.encoder = Encoder(
surf_bbox=surf_bbox, # use the given bbox as the surface bounding box
origin_bbox=origin_bbox, # use the same bbox as the origin bounding box
max_depth=max_depth,
octree=octree,
feature_dim=feature_dim
)
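
With the constructor change above, the octree is now built outside the network and passed in ready-made. A minimal sketch of the new call pattern, mirroring the Trainer changes later in this commit; the stand-in tensors, depth and dimensions are assumptions.

import numpy as np
import torch
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.networks.network import Net

device = "cuda" if torch.cuda.is_available() else "cpu"
surf_bbox = torch.rand(10, 6, device=device)          # (N, 6) per-face boxes, stand-in data
surf_bbox[:, 3:] = surf_bbox[:, :3] + 0.1             # keep max coords above min coords
global_bbox = torch.cat([surf_bbox[:, :3].min(0).values, surf_bbox[:, 3:].max(0).values])

root = OctreeNode(bbox=global_bbox,
                  face_indices=np.arange(surf_bbox.shape[0]),  # start with all faces
                  max_depth=4,
                  surf_bbox=surf_bbox)
root.conduct_tree()                                   # subdivide before handing the tree over

model = Net(surf_bbox=surf_bbox, octree=root, feature_dim=64).to(device)
sdfs = model(torch.rand(8, 3, device=device))         # query a few points, as test_load does below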

brep2sdf/networks/octree.py (139 lines changed)

@@ -28,21 +28,20 @@ def bbox_intersect(bbox1: torch.Tensor, bbox2: torch.Tensor) -> bool:
# vectorized comparison
return torch.all((max1 >= min2) & (max2 >= min1))
class OctreeNode:
feature_dim=None
class OctreeNode(nn.Module):
device=None
surf_bbox = None
def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, feature_dim:int = None, surf_bbox:torch.Tensor = None):
def __init__(self, bbox: torch.Tensor,face_indices: np.ndarray, max_depth: int = 5, surf_bbox:torch.Tensor = None):
super().__init__()
self.bbox = bbox # bounding box of this node
self.max_depth: int = max_depth # remaining depth; 0 means the maximum depth is reached and the node is not subdivided further
self.children: List['OctreeNode'] = [] # list of child nodes
self.child_nodes: List['OctreeNode'] = [] # list of child nodes
self.face_indices = face_indices
self.patch_feature_volume = None # patch feature volume; only leaf nodes have one
self.param_key = None
#self.patch_feature_volume = None # patch feature volume; only leaf nodes have one
self._is_leaf = True
#print(f"box shape: {self.bbox.shape}")
if feature_dim is not None:
OctreeNode.feature_dim = feature_dim
if surf_bbox is not None:
if not isinstance(surf_bbox, torch.Tensor):
raise TypeError(
@@ -56,13 +55,15 @@ class OctreeNode:
OctreeNode.device = bbox.device
def is_leaf(self):
# Check if self.children is None before calling len()
# Check if self.child_nodes is None before calling len()
return self._is_leaf
def set_param_key(self, k):
self.param_key = k
def conduct_tree(self):
if self.max_depth <= 0 or self.face_indices.shape[0] <= 2:
# reached the maximum depth, or the cell contains at most two faces
self.patch_feature_volume = nn.Parameter(torch.randn(8, OctreeNode.feature_dim, device=OctreeNode.device))
return
self.subdivide()
@@ -102,7 +103,7 @@ class OctreeNode:
])
# Create a child node for each sub-bounding box and assign the intersecting faces
self.children = []
self.child_nodes = []
for bbox in child_bboxes:
# find the faces that intersect this child bounding box
intersecting_faces = []
@@ -118,7 +119,7 @@
max_depth=self.max_depth - 1
)
child_node.conduct_tree()
self.children.append(child_node)
self.child_nodes.append(child_node)
self._is_leaf = False
@@ -143,6 +144,25 @@ class OctreeNode:
return index.unsqueeze(0)
def find_leaf(self, query_point:torch.Tensor):
# Recursively search from the root node for the leaf node containing this point
if self._is_leaf:
return self
else:
index = self.get_child_index(query_point)
try:
# access the child node directly, without an explicit check
return self.child_nodes[index].find_leaf(query_point)
except IndexError as e:
# log the error and re-raise the exception
logger.error(
f"Error accessing child node: {e}. "
f"Query point: {query_point.cpu().numpy().tolist()}, "
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, "
f"Depth info: {self.max_depth}"
)
raise e
def get_feature_vector(self, query_point:torch.Tensor):
"""
Return the feature vector used to predict the SDF at the given point
@@ -158,7 +178,7 @@ class OctreeNode:
index = self.get_child_index(query_point)
try:
# access the child node directly, without an explicit check
return self.children[index].get_feature_vector(query_point)
return self.child_nodes[index].get_feature_vector(query_point)
except IndexError as e:
# log the error and re-raise the exception
logger.error(
@@ -170,46 +190,7 @@ class OctreeNode:
raise e
def trilinear_interpolation(self, query_point: torch.Tensor) -> torch.Tensor:
"""
Trilinear interpolation.
:param query_point: point to interpolate at, in (x, y, z) format
:return: interpolated result, shape (D,)
"""
# Make sure query_point and bbox are on the same device
#query_point = query_point.to(self.bbox.device)
# Get the min and max coordinates of the bounding box
min_coords = self.bbox[:3] # [min_x, min_y, min_z]
max_coords = self.bbox[3:] # [max_x, max_y, max_z]
# Compute normalized coordinates
normalized_coords = (query_point - min_coords) / (max_coords - min_coords + 1e-8) # avoid division by zero
x, y, z = normalized_coords.unbind(dim=-1)
# Use torch.stack to avoid converting to Python scalars
wx = torch.stack([1 - x, x], dim=-1) # keeps autograd intact
wy = torch.stack([1 - y, y], dim=-1)
wz = torch.stack([1 - z, z], dim=-1)
# Feature vectors of the 8 corners
c = self.patch_feature_volume # shape (8, D)
# Perform trilinear interpolation
# interpolate along the x axis first
c00 = c[0] * wx[0] + c[1] * wx[1]
c01 = c[2] * wx[0] + c[3] * wx[1]
c10 = c[4] * wx[0] + c[5] * wx[1]
c11 = c[6] * wx[0] + c[7] * wx[1]
# then along the y axis
c0 = c00 * wy[0] + c10 * wy[1]
c1 = c01 * wy[0] + c11 * wy[1]
# finally along the z axis
result = c0 * wz[0] + c1 * wz[1]
return result
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None:
@@ -231,46 +212,32 @@ class OctreeNode:
# Print face info (if any)
if self.face_indices is not None:
print(f"{indent} Face indices: {self.face_indices.tolist()}")
print(f"{indent} len children: {len(self.children)}")
print(f"{indent} len child_nodes: {len(self.child_nodes)}")
# Recursively print child nodes
for i, child in enumerate(self.children):
for i, child in enumerate(self.child_nodes):
print(f"{indent} Child {i}:")
child.print_tree(depth + 1, max_print_depth)
# Serialization
def state_dict(self):
"""返回节点及其子树的state_dict"""
state = {
'bbox': self.bbox,
'max_depth': self.max_depth,
'face_indices': self.face_indices,
'is_leaf': self._is_leaf
}
if self._is_leaf:
state['patch_feature_volume'] = self.patch_feature_volume
else:
state['children'] = [child.state_dict() for child in self.children]
def __getstate__(self):
"""支持pickle序列化"""
return self._serialize_node(self)
return state
def __setstate__(self, state):
"""支持pickle反序列化"""
self = self._deserialize_node(state)
def load_state_dict(self, state_dict):
"""从state_dict加载节点状态"""
self.bbox = state_dict['bbox']
self.max_depth = state_dict['max_depth']
self.face_indices = state_dict['face_indices']
self._is_leaf = state_dict['is_leaf']
def _serialize_node(self, node):
return {
'bbox': node.bbox,
'is_leaf': node._is_leaf,
'child_nodes': [self._serialize_node(c) for c in node.child_nodes],
'param_key': node.param_key
}
if self._is_leaf:
self.patch_feature_volume = nn.Parameter(state_dict['patch_feature_volume'])
else:
self.children = []
for child_state in state_dict['children']:
child = OctreeNode(
bbox=child_state['bbox'],
face_indices=child_state['face_indices'],
max_depth=child_state['max_depth']
)
child.load_state_dict(child_state)
self.children.append(child)
def _deserialize_node(self, data):
node = OctreeNode(bbox=data['bbox'], face_indices=None, max_depth=0) # max_depth is rebuilt by the encoder; a bare positional 0 would land in face_indices
node._is_leaf = data['is_leaf']
node.param_key = data['param_key']
node.child_nodes = [self._deserialize_node(c) for c in data['child_nodes']]
return node
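
The __getstate__/__setstate__ pair above serializes only the bbox, the leaf flag, the parameter key and the child list; the learnable features stay in the Encoder's ParameterDict and are re-attached through param_key. A hedged round-trip sketch with the plain pickle module; the single-leaf tree and the tensors are stand-ins, and interaction with nn.Module's own pickling hooks is not covered here.

import pickle
import numpy as np
import torch
from brep2sdf.networks.octree import OctreeNode

bbox = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
surf_bbox = bbox.unsqueeze(0)                  # one face covering the whole cell
root = OctreeNode(bbox=bbox, face_indices=np.arange(1), max_depth=2, surf_bbox=surf_bbox)

blob = pickle.dumps(root)                      # goes through __getstate__
restored = pickle.loads(blob)                  # goes through __setstate__
assert restored.is_leaf() == root.is_leaf()
assert restored.param_key == root.param_key    # leaf features are looked up again via this key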

brep2sdf/train.py (51 lines changed)

@@ -1,4 +1,5 @@
import torch
from torch.serialization import add_safe_globals
import torch.optim as optim
import time
import os
@@ -9,6 +10,7 @@ from brep2sdf.config.default_config import get_default_config
from brep2sdf.data.data import load_brep_file,load_sdf_file
from brep2sdf.data.pre_process import process_single_step
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode
from brep2sdf.utils.logger import logger
def prepare_sdf_data(surf_data, max_points=100000, device='cuda'):
@@ -62,10 +64,13 @@ class Trainer:
dtype=torch.float32,
device=self.device
)
bbox = self._calculate_global_bbox(surf_bbox)
self.build_tree(surf_bbox=surf_bbox, max_depth=4)
self.model = Net(
surf_bbox=surf_bbox,
origin_bbox=bbox,
octree=self.root,
feature_dim=64
).to(self.device)
@@ -77,6 +82,20 @@
)
def build_tree(self,surf_bbox, max_depth=6):
num_faces = surf_bbox.shape[0]
bbox = self._calculate_global_bbox(surf_bbox)
self.root = OctreeNode(
bbox=bbox,
face_indices=np.arange(num_faces), # initially contains all faces
max_depth=max_depth,
surf_bbox=surf_bbox
)
#print(surf_bbox)
logger.info("starting octree conduction")
self.root.conduct_tree()
logger.info("complete octree conduction")
#self.root.print_tree(0)
def _calculate_global_bbox(self, surf_bbox: torch.Tensor) -> torch.Tensor:
"""
@@ -103,6 +122,7 @@
# Return the merged bounding box
return torch.cat([global_min, global_max])
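
For reference, the merged box returned above can be computed directly from the per-face boxes. A minimal sketch, assuming surf_bbox is an (N, 6) tensor of [min_xyz, max_xyz] rows; the standalone helper name is hypothetical.

import torch

def calculate_global_bbox(surf_bbox: torch.Tensor) -> torch.Tensor:
    # surf_bbox: (N, 6), each row is [min_x, min_y, min_z, max_x, max_y, max_z]
    global_min = surf_bbox[:, :3].min(dim=0).values   # element-wise minimum over all faces
    global_max = surf_bbox[:, 3:].max(dim=0).values   # element-wise maximum over all faces
    return torch.cat([global_min, global_max])        # (6,), same layout as OctreeNode.bbox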
def train_epoch(self, epoch: int) -> float:
self.model.train()
total_loss = 0.0
@@ -153,7 +173,7 @@
best_val_loss = float('inf')
logger.info("Starting training...")
start_time = time.time()
"""
for epoch in range(1, self.config.train.num_epochs + 1):
# Train one epoch
train_loss = self.train_epoch(epoch)
@@ -182,7 +202,18 @@
logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
logger.info(f'Best validation loss: {best_val_loss:.6f}')
self._tracing_model()
"""
self.test_load()
def test_load(self):
model = self._load_checkpoint("/home/wch/brep2sdf/brep2sdf/00000054.pt")
model.eval()
logger.debug(model)
example_input = torch.rand(10, 3, device=self.device)
#logger.debug(model.encoder.octree.bbox)
logger.debug(f"points: {example_input}")
sdfs = model(example_input)
logger.debug(f"sdfs:{sdfs}")
def _tracing_model(self):
"""保存模型"""
@@ -195,13 +226,8 @@
def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态"""
checkpoint = torch.load(checkpoint_path)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1 # resume from the next epoch
best_loss = checkpoint['loss']
logger.info(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
return start_epoch, best_loss
model = torch.load(checkpoint_path)
return model
def _save_checkpoint(self, epoch: int, train_loss: float):
"""保存训练检查点"""
@@ -211,6 +237,7 @@
)
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir,f"epoch_{epoch:03d}.pth")
'''
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
@@ -218,6 +245,8 @@
'loss': train_loss,
'config': self.config
}, checkpoint_path)
'''
torch.save(self.model, checkpoint_path)
def main():
# Configuration needs to be initialized here
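
Because the checkpoint now stores the whole pickled Net rather than a state_dict, loading it back must be allowed to reconstruct the custom classes. A hedged sketch of what test_load relies on; the add_safe_globals route mirrors the new import at the top of train.py and only matters on PyTorch builds where torch.load defaults to weights_only=True, and the checkpoint path here is hypothetical.

import torch
from torch.serialization import add_safe_globals
from brep2sdf.networks.network import Net
from brep2sdf.networks.octree import OctreeNode

checkpoint_path = "checkpoints/epoch_001.pth"   # hypothetical path

# Either trust the file explicitly and let pickle rebuild everything ...
model = torch.load(checkpoint_path, weights_only=False)

# ... or keep the safe loader and allowlist the classes the pickle contains
# (the exact list depends on what ends up inside the checkpoint):
# add_safe_globals([Net, OctreeNode])
# model = torch.load(checkpoint_path)

model.eval()
with torch.no_grad():
    sdfs = model(torch.rand(10, 3, device=next(model.parameters()).device))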
