Browse Source

优化写法,兼容 torch

final
mckay 2 months ago
parent
commit
a7541c9da3
  1. 183
      brep2sdf/networks/octree.py
  2. 37
      brep2sdf/train.py

183
brep2sdf/networks/octree.py

@ -69,24 +69,37 @@ class OctreeNode:
def subdivide(self):
min_x, min_y, min_z, max_x, max_y, max_z = self.bbox
#min_x, min_y, min_z, max_x, max_y, max_z = self.bbox
# 使用索引操作替代解包
min_coords = self.bbox[:3] # [min_x, min_y, min_z]
max_coords = self.bbox[3:] # [max_x, max_y, max_z]
# 计算中间点
mid_x = (min_x + max_x) / 2
mid_y = (min_y + max_y) / 2
mid_z = (min_z + max_z) / 2
mid_coords = (min_coords + max_coords) / 2
# 提取 min_x, min_y, min_z, mid_x, mid_y, mid_z
min_x, min_y, min_z = min_coords[0], min_coords[1], min_coords[2]
mid_x, mid_y, mid_z = mid_coords[0], mid_coords[1], mid_coords[2]
max_x, max_y, max_z = max_coords[0], max_coords[1], max_coords[2]
# 生成 8 个子包围盒
child_bboxes = torch.tensor([
[min_x, min_y, min_z, mid_x, mid_y, mid_z], # 前下左
[mid_x, min_y, min_z, max_x, mid_y, mid_z], # 前下右
[min_x, mid_y, min_z, mid_x, max_y, mid_z], # 前上左
[mid_x, mid_y, min_z, max_x, max_y, mid_z], # 前上右
[min_x, min_y, mid_z, mid_x, mid_y, max_z], # 后下左
[mid_x, min_y, mid_z, max_x, mid_y, max_z], # 后下右
[min_x, mid_y, mid_z, mid_x, max_y, max_z], # 后上左
[mid_x, mid_y, mid_z, max_x, max_y, max_z] # 后上右
], dtype=torch.float32, device=OctreeNode.device)
child_bboxes = torch.stack([
torch.cat([min_coords, mid_coords]), # 前下左
torch.cat([torch.tensor([mid_x, min_y, min_z], device=self.bbox.device),
torch.tensor([max_x, mid_y, mid_z], device=self.bbox.device)]), # 前下右
torch.cat([torch.tensor([min_x, mid_y, min_z], device=self.bbox.device),
torch.tensor([mid_x, max_y, mid_z], device=self.bbox.device)]), # 前上左
torch.cat([torch.tensor([mid_x, mid_y, min_z], device=self.bbox.device),
torch.tensor([max_x, max_y, mid_z], device=self.bbox.device)]), # 前上右
torch.cat([torch.tensor([min_x, min_y, mid_z], device=self.bbox.device),
torch.tensor([mid_x, mid_y, max_z], device=self.bbox.device)]), # 后下左
torch.cat([torch.tensor([mid_x, min_y, mid_z], device=self.bbox.device),
torch.tensor([max_x, mid_y, max_z], device=self.bbox.device)]), # 后下右
torch.cat([torch.tensor([min_x, mid_y, mid_z], device=self.bbox.device),
torch.tensor([mid_x, max_y, max_z], device=self.bbox.device)]), # 后上左
torch.cat([torch.tensor([mid_x, mid_y, mid_z], device=self.bbox.device),
torch.tensor([max_x, max_y, max_z], device=self.bbox.device)]) # 后上右
])
# 为每个子包围盒创建子节点,并分配相交的面
self.children = []
@ -98,7 +111,7 @@ class OctreeNode:
if bbox_intersect(bbox, face_bbox):
intersecting_faces.append(face_idx)
#print(f"{bbox}: {intersecting_faces}")
if intersecting_faces:
child_node = OctreeNode(
bbox=bbox,
face_indices=np.array(intersecting_faces),
@ -112,29 +125,23 @@ class OctreeNode:
def get_child_index(self, query_point: torch.Tensor) -> int:
"""
计算点所在子节点的索引
:param point: 待检查的点格式为 (x, y, z)
:param query_point: 待检查的点格式为 (x, y, z)
:return: 子节点的索引范围从 0 7
"""
#print(query_point)
x, y, z = query_point
#logger.info(f"query_point: {query_point}")
#logger.info(f"box: {self.bbox}")
min_x, min_y, min_z, max_x, max_y, max_z = self.bbox
# 确保 query_point 和 bbox 在同一设备上
query_point = query_point.to(self.bbox.device)
# 提取 bbox 的最小和最大坐标
min_coords = self.bbox[:3] # [min_x, min_y, min_z]
max_coords = self.bbox[3:] # [max_x, max_y, max_z]
# 计算中间点
mid_x = (min_x + max_x) / 2
mid_y = (min_y + max_y) / 2
mid_z = (min_z + max_z) / 2
index = 0
if x >= mid_x: # 修正变量名
index += 1
if y >= mid_y: # 修正变量名
index += 2
if z >= mid_z: # 修正变量名
index += 4
#logger.info(f"index: {index}")
return index
mid_coords = (min_coords + max_coords) / 2
# 使用布尔比较结果计算索引
index = ((query_point >= mid_coords) << torch.arange(3, device=self.bbox.device)).sum()
return index.unsqueeze(0)
def get_feature_vector(self, query_point:torch.Tensor):
"""
@ -150,58 +157,59 @@ class OctreeNode:
else:
index = self.get_child_index(query_point)
try:
if index < 0 or index >= len(self.children):
raise IndexError(
f"Child index {index} out of range (0-{len(self.children)-1}) "
f"for query point {query_point.cpu().numpy().tolist()}. "
f"Node bbox: {self.bbox.cpu().numpy().tolist()}"
f"dept info: {self.max_depth}"
)
# 直接访问子节点,不进行显式检查
return self.children[index].get_feature_vector(query_point)
except IndexError as e:
logger.error(str(e))
# 记录错误日志并重新抛出异常
logger.error(
f"Error accessing child node: {e}. "
f"Query point: {query_point.cpu().numpy().tolist()}, "
f"Node bbox: {self.bbox.cpu().numpy().tolist()}, "
f"Depth info: {self.max_depth}"
)
raise e
def trilinear_interpolation(self, query_point:torch.Tensor) -> torch.Tensor:
"""
使用三线性插值从补丁特征体积中获取查询点的特征向量
:param query_point: 查询点的位置坐标
:return: 插值后的特征向量
"""
"""三线性插值"""
def trilinear_interpolation(self, query_point: torch.Tensor) -> torch.Tensor:
"""
实现三线性插值
:param query_point: 待插值的点格式为 (x, y, z)
:return: 插值结果形状为 (D,)
"""
# 获取包围盒的边界
min_x, min_y, min_z, max_x, max_y, max_z = self.bbox
# 确保 query_point 和 bbox 在同一设备上
#query_point = query_point.to(self.bbox.device)
# 获取包围盒的最小和最大坐标
min_coords = self.bbox[:3] # [min_x, min_y, min_z]
max_coords = self.bbox[3:] # [max_x, max_y, max_z]
# 计算归一化坐标
x = (query_point[0] - min_x) / (max_x - min_x)
y = (query_point[1] - min_y) / (max_y - min_y)
z = (query_point[2] - min_z) / (max_z - min_z)
normalized_coords = (query_point - min_coords) / (max_coords - min_coords + 1e-8) # 防止除零错误
x, y, z = normalized_coords.unbind(dim=-1)
# 使用torch.stack避免Python标量转换
wx = torch.stack([1 - x, x], dim=-1) # 保持自动微分
wy = torch.stack([1 - y, y], dim=-1)
wz = torch.stack([1 - z, z], dim=-1)
# 获取8个顶点的特征向量
c000 = self.patch_feature_volume[0]
c100 = self.patch_feature_volume[1]
c010 = self.patch_feature_volume[2]
c110 = self.patch_feature_volume[3]
c001 = self.patch_feature_volume[4]
c101 = self.patch_feature_volume[5]
c011 = self.patch_feature_volume[6]
c111 = self.patch_feature_volume[7]
c = self.patch_feature_volume # 形状为 (8, D)
# 执行三线性插值
c00 = c000 * (1 - x) + c100 * x
c01 = c001 * (1 - x) + c101 * x
c10 = c010 * (1 - x) + c110 * x
c11 = c011 * (1 - x) + c111 * x
# 先对 x 轴插值
c00 = c[0] * wx[0] + c[1] * wx[1]
c01 = c[2] * wx[0] + c[3] * wx[1]
c10 = c[4] * wx[0] + c[5] * wx[1]
c11 = c[6] * wx[0] + c[7] * wx[1]
c0 = c00 * (1 - y) + c10 * y
c1 = c01 * (1 - y) + c11 * y
# 再对 y 轴插值
c0 = c00 * wy[0] + c10 * wy[1]
c1 = c01 * wy[0] + c11 * wy[1]
return c0 * (1 - z) + c1 * z
# 最后对 z 轴插值
result = c0 * wz[0] + c1 * wz[1]
return result
def print_tree(self, depth: int = 0, max_print_depth: int = None) -> None:
@ -229,3 +237,40 @@ class OctreeNode:
for i, child in enumerate(self.children):
print(f"{indent} Child {i}:")
child.print_tree(depth + 1, max_print_depth)
# 保存
def state_dict(self):
"""返回节点及其子树的state_dict"""
state = {
'bbox': self.bbox,
'max_depth': self.max_depth,
'face_indices': self.face_indices,
'is_leaf': self._is_leaf
}
if self._is_leaf:
state['patch_feature_volume'] = self.patch_feature_volume
else:
state['children'] = [child.state_dict() for child in self.children]
return state
def load_state_dict(self, state_dict):
"""从state_dict加载节点状态"""
self.bbox = state_dict['bbox']
self.max_depth = state_dict['max_depth']
self.face_indices = state_dict['face_indices']
self._is_leaf = state_dict['is_leaf']
if self._is_leaf:
self.patch_feature_volume = nn.Parameter(state_dict['patch_feature_volume'])
else:
self.children = []
for child_state in state_dict['children']:
child = OctreeNode(
bbox=child_state['bbox'],
face_indices=child_state['face_indices'],
max_depth=child_state['max_depth']
)
child.load_state_dict(child_state)
self.children.append(child)

37
brep2sdf/train.py

@ -35,7 +35,7 @@ class Trainer:
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model_name = os.path.basename(input_step).replace(".step", "")
self.model_name = os.path.basename(input_step).split('_')[0]
self.base_name = self.model_name + ".xyz"
data_path = os.path.join("/home/wch/brep2sdf/data/output_data",self.base_name)
if os.path.exists(data_path):
@ -110,6 +110,7 @@ class Trainer:
# 获取数据并移动到设备
points = self.sdf_data[:,0:3]
points.requires_grad_(True)
gt_sdf = self.sdf_data[:,3]
# 前向传播
@ -180,23 +181,27 @@ class Trainer:
total_time = time.time() - start_time
logger.info(f'Training completed in {total_time//3600:.0f}h {(total_time%3600)//60:.0f}m {total_time%60:.2f}s')
logger.info(f'Best validation loss: {best_val_loss:.6f}')
self._tracing_model()
def _save_model(self, epoch: int, val_loss: float):
"""保存最佳模型"""
save_path = os.path.join(
self.config.train.checkpoint_dir,
self.config.train.best_model_name.format(
model_name=config.train.model_name
)
)
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'loss': val_loss,
'config': self.config
}, save_path)
def _tracing_model(self):
"""保存模型"""
self.model.eval()
example_input = torch.rand(10, 3, device=self.device)
# 3. 在no_grad上下文中执行追踪
with torch.no_grad():
traced_model = torch.jit.trace(self.model, example_input)
torch.jit.save(traced_model, f"{self.model_name}.pt")
def _load_checkpoint(self, checkpoint_path):
"""从检查点恢复训练状态"""
checkpoint = torch.load(checkpoint_path)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1 # 从下一轮开始
best_loss = checkpoint['loss']
logger.info(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
return start_epoch, best_loss
def _save_checkpoint(self, epoch: int, train_loss: float):
"""保存训练检查点"""

Loading…
Cancel
Save