
Background field training

Branch: final
mckay, 1 month ago
Commit: 2bed5b89cd
Changed files:
1. brep2sdf/batch_train.py (7)
2. brep2sdf/networks/encoder.py (48)
3. brep2sdf/networks/feature_volume.py (7)
4. brep2sdf/networks/loss.py (24)
5. brep2sdf/networks/network.py (64)
6. brep2sdf/networks/sample.py (2)
7. brep2sdf/train.py (314)
8. brep2sdf/utils/load.py (133)
9. data/scripts/pre_processing1.py (18)

brep2sdf/batch_train.py (7)

@@ -70,7 +70,8 @@ def batch_train(args):
common_train_args = [
"--use-normal",
"--only-zero-surface",
#"--force-reprocess",
"--octree-cuda",
"--force-reprocess",
# additional common arguments can be added here
]
if args.train_args:
@@ -249,8 +250,8 @@ def batch_Iso(args):
logger.info(f"Processing finished. Succeeded: {success_count}, failed: {failure_count}")
def main(args):
#batch_train(args)
batch_Iso(args)
batch_train(args)
#batch_Iso(args)
if __name__ == '__main__':

brep2sdf/networks/encoder.py (48)

@@ -29,10 +29,17 @@ class Encoder(nn.Module):
) for i, bbox in enumerate(volume_bboxs)
])
self.background = PatchFeatureVolume(
bbox=torch.Tensor([-0.5, -0.5, -0.5, 0.5, 0.5, 0.5]), # corrected normalized bbox
resolution=int(resolutions.max()) * 2,
feature_dim=feature_dim
self.background = self.simple_encoder = nn.Sequential(
nn.Linear(3, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, feature_dim)
)
print(f"Initialized {len(self.feature_volumes)} feature volumes with resolutions: {resolutions.tolist()}")
print(f"Model parameters memory usage: {sum(p.numel() * p.element_size() for p in self.parameters()) / 1024**2:.2f} MB")
@@ -88,6 +95,20 @@ class Encoder(nn.Module):
return all_features
def forward_background(self, query_points: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the background field only.
Args:
query_points: query point coordinates (B, 3)
Returns:
feature tensor (B, D)
"""
background_features = self.background.forward(query_points) # (B, D)
return background_features
@torch.jit.export
def forward_training_volumes(self, surf_points: torch.Tensor, patch_id: int) -> torch.Tensor:
"""
@@ -138,3 +159,22 @@ class Encoder(nn.Module):
_move_node(child)
_move_node(self.octree.root)
return self
def freeze_stage1(self):
for volume in self.feature_volumes:
for param in volume.parameters():
param.requires_grad = False
for param in self.background.parameters():
param.requires_grad = False
def freeze_stage2(self):
for volume in self.feature_volumes:
for param in volume.parameters():
param.requires_grad = True
for param in self.background.parameters():
param.requires_grad = False
def unfreeze(self):
for volume in self.feature_volumes:
for param in volume.parameters():
param.requires_grad = True
for param in self.background.parameters():
param.requires_grad = True
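
The three freeze helpers encode the staged schedule this commit is aiming at: stage 1 freezes both the patch volumes and the background, stage 2 re-enables only the patch volumes, and unfreeze re-enables everything. A minimal sketch (not part of the commit) of what each helper leaves trainable, assuming `encoder` is an initialised Encoder:

def check_freeze_schedule(encoder):
    # Illustrative only: `encoder` is assumed to be an initialised Encoder instance.
    encoder.freeze_stage1()   # stage 1: patch volumes and background both frozen
    assert not any(p.requires_grad for v in encoder.feature_volumes for p in v.parameters())
    assert not any(p.requires_grad for p in encoder.background.parameters())
    encoder.freeze_stage2()   # stage 2: patch volumes trainable, background stays frozen
    assert all(p.requires_grad for v in encoder.feature_volumes for p in v.parameters())
    assert not any(p.requires_grad for p in encoder.background.parameters())
    encoder.unfreeze()        # everything trainable again
    assert all(p.requires_grad for p in encoder.background.parameters())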

brep2sdf/networks/feature_volume.py (7)

@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
class PatchFeatureVolume(nn.Module):
def __init__(self, bbox: torch.Tensor, resolution=64, feature_dim=64, padding_ratio=0.05):
def __init__(self, bbox: torch.Tensor, resolution=64, feature_dim=8, padding_ratio=0.05):
super(PatchFeatureVolume, self).__init__()
# convert the input bbox to [min, max] format
self.resolution = resolution
@@ -19,8 +19,9 @@ class PatchFeatureVolume(nn.Module):
grid_x, grid_y, grid_z = torch.meshgrid(x, y, z)
self.register_buffer('grid', torch.stack([grid_x, grid_y, grid_z], dim=-1))
# initialize the feature vectors
self.feature_volume = nn.Parameter(torch.randn(resolution, resolution, resolution, feature_dim))
# initialize the feature vectors to small values using a small standard deviation
self.feature_volume = nn.Parameter(torch.empty(resolution, resolution, resolution, feature_dim))
torch.nn.init.normal_(self.feature_volume, mean=0.0, std=0.01) # std set to 0.01; adjust as needed
def _expand_bbox(self, min_coords, max_coords, ratio):
# expand the bounding box range
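
For reference, a tiny standalone sketch (not from this commit) contrasting the old and new initialisation scales; with std=0.01 the stored features, and hence the interpolated features, start out close to zero:

import torch

old_feat = torch.randn(64, 64, 64, 8)                 # previous init, std ~ 1.0
new_feat = torch.empty(64, 64, 64, 8)
torch.nn.init.normal_(new_feat, mean=0.0, std=0.01)   # new init, std ~ 0.01
print(old_feat.std().item(), new_feat.std().item())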

brep2sdf/networks/loss.py (24)

@@ -5,14 +5,15 @@ from brep2sdf.utils.logger import logger
class LossManager:
def __init__(self, ablation, **condition_kwargs):
self.weights = {
"manifold": 10,
"manifold": 1,
"feature_manifold": 1, # same weight as manifold in the original paper
"normals": 1,
"eikonal": 1,
"offsurface": 1,
"consistency": 1,
"correction": 1,
"psdf": 10
"psdf": 1,
"psdf_sign_loss": 0
}
self.condition_kwargs = condition_kwargs
self.ablation = ablation # 消融实验用
@@ -111,6 +112,21 @@ class LossManager:
correction_loss = (a_correction * torch.abs(mnfld_pred - all_fi[:,0])[mismatch_id]).mean() # compute the correction loss
return correction_loss
def psdf_loss(self, pred_sdfs, gt_sdfs):
# weights for matching and mismatching signs
weight_same_sign = 1.0 # weight when the signs agree
weight_different_sign = 10.0 # weight when the signs disagree
# check whether the signs agree
same_sign = (pred_sdfs * gt_sdfs) >= 0
# pick the weight according to the sign agreement
weights = torch.where(same_sign, weight_same_sign, weight_different_sign)
squared_diff = torch.pow(pred_sdfs - gt_sdfs, 2)
weighted_squared_diff = weights * squared_diff
return torch.mean(weighted_squared_diff)
def compute_loss(self,
mnfld_pnts,
@@ -150,7 +166,7 @@ class LossManager:
# compute the correction loss
#correction_loss = self.correction_loss(mnfld_pnts, mnfld_pred, all_fi)
psdf_loss = self.position_loss(nonmnfld_pred, psdfs)
psdf_loss = self.psdf_loss(nonmnfld_pred, psdfs)
# aggregate the losses
loss_details = {
@@ -158,7 +174,7 @@ class LossManager:
"normals": self.weights["normals"] * normals_loss,
"eikonal": self.weights["eikonal"] * eikonal_loss,
"offsurface": self.weights["offsurface"] * offsurface_loss,
"psdf":self.weights["psdf"] * psdf_loss
"psdf":self.weights["psdf"] * psdf_loss,
}
# compute the total loss
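
The new psdf_loss replaces the plain position loss with a sign-aware squared error: points whose predicted SDF lands on the wrong side of the surface are penalised ten times harder. A minimal standalone sketch of the same weighting (not the project's class, just the formula):

import torch

def sign_weighted_sdf_loss(pred_sdfs, gt_sdfs, w_same=1.0, w_diff=10.0):
    # Squared error, but sign mismatches between prediction and ground truth get w_diff.
    same_sign = (pred_sdfs * gt_sdfs) >= 0
    weights = torch.where(same_sign,
                          torch.full_like(pred_sdfs, w_same),
                          torch.full_like(pred_sdfs, w_diff))
    return torch.mean(weights * (pred_sdfs - gt_sdfs) ** 2)

pred = torch.tensor([0.10, -0.05, 0.30])
gt   = torch.tensor([0.20,  0.05, 0.30])
print(sign_weighted_sdf_loss(pred, gt))  # only the middle element carries the 10x weight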

brep2sdf/networks/network.py (64)

@@ -56,7 +56,7 @@ class Net(nn.Module):
def __init__(self,
octree,
volume_bboxs,
feature_dim=64,
feature_dim=8,
decoder_output_dim=1,
decoder_hidden_dim=256,
decoder_num_layers=4,
@@ -87,8 +87,7 @@
output = f_i[:,0]
# extract valid values and pad to a fixed size (B, max_patches)
padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches)
valid_mask = face_indices_mask.bool() # make sure it is boolean (B, P)
masked_f_i = torch.where(valid_mask, f_i, torch.tensor(float('inf'), device=f_i.device)) # set invalid entries to inf
masked_f_i = torch.where(face_indices_mask, f_i, torch.tensor(float('inf'), device=f_i.device)) # set invalid entries to inf
# keep the first max_patches valid values per sample (B, max_patches)
valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False) # take the two smallest valid values
@@ -108,7 +107,6 @@
if mask_convex.any():
output[mask_convex] = torch.min(padded_f_i[mask_convex], dim=1).values
logger.debug("step over")
#logger.gpu_memory_stats("after combine")
return output
@@ -138,6 +136,45 @@ class Net(nn.Module):
#logger.debug("step combine")
return self.process_sdf(f_i, face_indices_mask, operator)
@torch.jit.export
def forward_background(self, query_points):
"""
Forward pass through the background field.
Args:
query_points: query point coordinates
Returns:
output: decoded output
"""
# batch-query indices and bounding boxes for all points
# encode
feature_vectors = self.encoder.forward_background(query_points)
# decode
h = self.decoder.forward_training_volumes(feature_vectors) # (B, D)
return h
@torch.jit.export
def forward_without_octree(self, query_points,face_indices_mask,operator):
"""
Forward pass without the octree lookup.
Args:
query_points: query point coordinates
Returns:
output: decoded output
"""
# batch-query indices and bounding boxes for all points
#logger.debug("step encode")
# encode
feature_vectors = self.encoder.forward(query_points,face_indices_mask)
#print("feature_vector:", feature_vectors.shape)
# decode
f_i = self.decoder(feature_vectors) # (B, P)
#logger.gpu_memory_stats("after decoder forward")
#logger.debug("step combine")
return self.process_sdf(f_i, face_indices_mask, operator)
@torch.jit.export
def forward_training_volumes(self, surf_points, patch_id:int):
"""
@@ -154,12 +191,27 @@ class Net(nn.Module):
def gradient(inputs, outputs):
# issue 1: inputs may contain non-coordinate features
# issue 2: special cases in the batch dimension are not handled
d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
# improved computation
points_grad = grad(
outputs=outputs,
inputs=inputs,
grad_outputs=d_points,
create_graph=True,
retain_graph=True,
only_inputs=True)[0][:, -3:]
return points_grad
only_inputs=True,
allow_unused=True # newly added to tolerate unused inputs
)[0]
# fixed the dimension slicing
if points_grad is None:
return torch.zeros_like(inputs[:, -3:]) # handle the empty-gradient case
# safe slicing and normalization
coord_grad = points_grad[:, -3:] if points_grad.shape[1] >=3 else points_grad
coord_grad = coord_grad / (coord_grad.norm(dim=-1, keepdim=True) + 1e-6) # safe normalization
return coord_grad
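
For context, a self-contained sketch (illustrative names, not the project's API) of how such a coordinate-gradient helper is typically used for an eikonal term; allow_unused=True lets it return zeros instead of raising when the output does not depend on the inputs:

import torch
from torch.autograd import grad

def sdf_gradient(points, sdf_values):
    # d(sdf)/d(points), kept in the graph so the gradient itself can be optimised.
    g = grad(outputs=sdf_values, inputs=points,
             grad_outputs=torch.ones_like(sdf_values),
             create_graph=True, retain_graph=True,
             only_inputs=True, allow_unused=True)[0]
    return torch.zeros_like(points) if g is None else g

points = torch.randn(16, 3, requires_grad=True)
sdf = points.norm(dim=-1) - 0.5                      # toy sphere SDF
g = sdf_gradient(points, sdf)
eikonal = ((g.norm(dim=-1) - 1.0) ** 2).mean()       # ~0 for a true SDF
print(eikonal.item())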

brep2sdf/networks/sample.py (2)

@@ -3,7 +3,7 @@ import torch
class NormalPerPoint():
def __init__(self, global_sigma, local_sigma=0.5):
def __init__(self, global_sigma, local_sigma=0.001):
self.global_sigma = global_sigma
self.local_sigma = local_sigma
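
local_sigma drops from 0.5 to 0.001, so locally perturbed samples now stay very close to the surface. A rough sketch of the effect, assuming the sampler perturbs each surface point with Gaussian noise of this scale (the exact get_norm_points implementation is not shown in this diff):

import torch

surface_pts = torch.rand(1024, 3) - 0.5          # toy surface samples in [-0.5, 0.5]^3
for local_sigma in (0.5, 0.001):
    offsets = torch.randn_like(surface_pts) * local_sigma
    print(local_sigma, offsets.norm(dim=-1).mean().item())  # mean offset shrinks by ~500x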

brep2sdf/train.py (314)

@@ -75,6 +75,8 @@ class Trainer:
self.data = process_single_step(step_path=input_step, output_path=data_path, sample_normal_vector=args.use_normal,sample_sdf_points=not args.only_zero_surface)
logger.gpu_memory_stats("after data preprocessing")
self.train_surf_ncs = torch.tensor(self.data["train_surf_ncs"],dtype=torch.float32,device=self.device) #
# convert the list of surface point clouds into an (N*M, 4) array
surfs = self.data["surf_ncs"]
@@ -98,6 +100,7 @@
else:
self.sdf_data = surface_sdf_data
print_data_distribution(self.sdf_data)
logger.debug(self.sdf_data.shape)
logger.gpu_memory_stats("after SDF data preparation")
# initialize the dataset
#self.brep_data = load_brep_file(self.config.data.pkl_path)
@@ -200,7 +203,7 @@
# # return the merged bounding box
# return torch.cat([global_min, global_max])
# return [-0.5,] # this is wrong
def train_epoch_stage1(self, epoch: int):
def train_epoch_stage1_(self, epoch: int):
total_loss = 0.0
total_loss_details = {
"manifold": 0.0,
@@ -293,8 +296,303 @@
return total_loss # return the average loss rather than the accumulated value
def train_epoch_stage1(self, epoch: int) -> float:
# --- 1. check the input data ---
# note: self.train_surf_ncs is assumed to contain [x, y, z, nx, ny, nz, sdf] (7 columns) or [x, y, z, sdf] (4 columns)
# with the SDF value always in the last column
if self.train_surf_ncs is None:
logger.error(f"Epoch {epoch}: self.train_surf_ncs is None. Cannot train.")
return float('inf')
self.model.train()
total_loss = 0.0
step = 0 # if training is split into batches, this should be the batch index
batch_size = 8192 # choose an appropriate batch size
# data preparation
# manifold
_mnfld_pnts = self.train_surf_ncs[:, 0:3].clone().detach() # first 3 columns: points
_normals = self.train_surf_ncs[:, 3:6].clone().detach() # middle 3 columns: normals
_gt_sdf = self.train_surf_ncs[:, -1].clone().detach() # last column: ground-truth SDF
# check whether the cache needs to be recomputed
if epoch % 10 == 1 or self.cached_train_data is None:
# compute masks and operators for the manifold points
# generate non-manifold points
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
# update the cache
self.cached_train_data = {
"nonmnfld_pnts": _nonmnfld_pnts,
"psdf": _psdf,
}
else:
# read from the cache
_nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
_psdf = self.cached_train_data["psdf"]
logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
# split the data into batches
num_points = self.train_surf_ncs.shape[0]
num_batches = (num_points + batch_size - 1) // batch_size
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, num_points)
# get the data for the current batch
mnfld_pnts = _mnfld_pnts[start_idx:end_idx] # manifold points
gt_sdf = _gt_sdf[start_idx:end_idx] # ground-truth SDF
normals = _normals[start_idx:end_idx] if args.use_normal else None # normals
# non-manifold points come from the cache (shared across the batch)
nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
psdf = _psdf[start_idx:end_idx]
# --- prepare model inputs and enable gradients ---
mnfld_pnts.requires_grad_(True) # enable gradients after the checks
nonmnfld_pnts.requires_grad_(True) # enable gradients after the checks
# --- forward pass ---
mnfld_pred = self.model.forward_background(
mnfld_pnts
)
nonmnfld_pred = self.model.forward_background(
nonmnfld_pnts
)
# --- compute the loss ---
loss = torch.tensor(float('nan'), device=self.device) # initialize to NaN in case the computation fails
loss_details = {}
try:
# --- 3. check the inputs before computing the loss ---
# (points already have gradients enabled; no need to re-check inf/nan, but check pred_sdf and gt_sdf)
#if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
#if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
if args.use_normal:
# check the normals and the gradient-enabled points
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("before loss computation")
loss, loss_details = self.loss_manager.compute_loss(
mnfld_pnts,
nonmnfld_pnts,
normals, # pass the checked normals
gt_sdf,
mnfld_pred,
nonmnfld_pred,
psdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 4. check the result of the loss computation ---
if self.debug_mode:
logger.print_tensor_stats("psdf",psdf)
logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details: logger.error(f"Loss Details: {loss_details}")
return float('inf') # stop this epoch if the loss is invalid
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # stop this epoch if the computation failed
logger.gpu_memory_stats("after loss computation")
def train_epoch(self, epoch: int) -> float:
# --- backward pass and optimization ---
try:
# backward pass
self.scheduler.optimizer.zero_grad() # clear gradients
loss.backward() # backpropagate
self.scheduler.optimizer.step() # update parameters
self.scheduler.step(loss,epoch)
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# enable anomaly detection to see which operation caused it
# torch.autograd.set_detect_anomaly(True) # place before training starts
return float('inf') # stop this epoch if backward or the optimizer step failed
# --- record and accumulate the loss ---
current_loss = loss.item()
if not np.isfinite(current_loss): # double-check the loss is a finite number
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
total_loss += current_loss
del loss
torch.cuda.empty_cache()
if epoch % 100 == 0:
# log training progress (valid losses only)
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
return total_loss # for single-batch training, just return the current loss
def train_epoch_stage2(self, epoch: int) -> float:
# --- 1. check the input data ---
# note: self.sdf_data is assumed to contain [x, y, z, nx, ny, nz, sdf] (7 columns) or [x, y, z, sdf] (4 columns)
# with the SDF value always in the last column
if self.sdf_data is None:
logger.error(f"Epoch {epoch}: self.sdf_data is None. Cannot train.")
return float('inf')
self.model.train()
total_loss = 0.0
step = 0 # if training is split into batches, this should be the batch index
batch_size = 8192 * 2 # choose an appropriate batch size
# data preparation
# manifold
_mnfld_pnts = self.sdf_data[:, 0:3].clone().detach() # first 3 columns: points
_normals = self.sdf_data[:, 3:6].clone().detach() # middle 3 columns: normals
_gt_sdf = self.sdf_data[:, -1].clone().detach() # last column: ground-truth SDF
# check whether the cache needs to be recomputed
if epoch % 10 == 1 or self.cached_train_data is None:
# compute masks and operators for the manifold points
_, _mnfld_face_indices_mask, _mnfld_operator = self.root.forward(_mnfld_pnts)
# generate non-manifold points
_nonmnfld_pnts, _psdf = self.sampler.get_norm_points(_mnfld_pnts, _normals)
_, _nonmnfld_face_indices_mask, _nonmnfld_operator = self.root.forward(_nonmnfld_pnts)
# update the cache
self.cached_train_data = {
"mnfld_face_indices_mask": _mnfld_face_indices_mask,
"mnfld_operator": _mnfld_operator,
"nonmnfld_pnts": _nonmnfld_pnts,
"psdf": _psdf,
"nonmnfld_face_indices_mask": _nonmnfld_face_indices_mask,
"nonmnfld_operator": _nonmnfld_operator
}
else:
# read from the cache
_mnfld_face_indices_mask = self.cached_train_data["mnfld_face_indices_mask"]
_mnfld_operator = self.cached_train_data["mnfld_operator"]
_nonmnfld_pnts = self.cached_train_data["nonmnfld_pnts"]
_psdf = self.cached_train_data["psdf"]
_nonmnfld_face_indices_mask = self.cached_train_data["nonmnfld_face_indices_mask"]
_nonmnfld_operator = self.cached_train_data["nonmnfld_operator"]
logger.debug((_mnfld_pnts, _nonmnfld_pnts, _psdf))
# split the data into batches
num_points = self.sdf_data.shape[0]
num_batches = (num_points + batch_size - 1) // batch_size
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, num_points)
# get the data for the current batch
mnfld_pnts = _mnfld_pnts[start_idx:end_idx] # manifold points
gt_sdf = _gt_sdf[start_idx:end_idx] # ground-truth SDF
normals = _normals[start_idx:end_idx] if args.use_normal else None # normals
# non-manifold points come from the cache (shared across the batch)
nonmnfld_pnts = _nonmnfld_pnts[start_idx:end_idx]
psdf = _psdf[start_idx:end_idx]
# --- prepare model inputs and enable gradients ---
mnfld_pnts.requires_grad_(True) # enable gradients after the checks
nonmnfld_pnts.requires_grad_(True) # enable gradients after the checks
# --- forward pass ---
mnfld_pred = self.model.forward_without_octree(
mnfld_pnts,
_mnfld_face_indices_mask[start_idx:end_idx],
_mnfld_operator[start_idx:end_idx]
)
nonmnfld_pred = self.model.forward_without_octree(
nonmnfld_pnts,
_nonmnfld_face_indices_mask[start_idx:end_idx],
_nonmnfld_operator[start_idx:end_idx]
)
#logger.print_tensor_stats("psdf",psdf)
#logger.print_tensor_stats("nonmnfld_pred",nonmnfld_pred)
# --- compute the loss ---
loss = torch.tensor(float('nan'), device=self.device) # initialize to NaN in case the computation fails
loss_details = {}
try:
# --- 3. check the inputs before computing the loss ---
# (points already have gradients enabled; no need to re-check inf/nan, but check pred_sdf and gt_sdf)
#if check_tensor(pred_sdf, "Predicted SDF (Loss Input)", epoch, step): raise ValueError("Bad pred_sdf before loss")
#if check_tensor(gt_sdf, "GT SDF (Loss Input)", epoch, step): raise ValueError("Bad gt_sdf before loss")
if args.use_normal:
# check the normals and the gradient-enabled points
#if check_tensor(normals, "Normals (Loss Input)", epoch, step): raise ValueError("Bad normals before loss")
#if check_tensor(points, "Points (Loss Input)", epoch, step): raise ValueError("Bad points before loss")
logger.gpu_memory_stats("before loss computation")
loss, loss_details = self.loss_manager.compute_loss(
mnfld_pnts,
nonmnfld_pnts,
normals, # pass the checked normals
gt_sdf,
mnfld_pred,
nonmnfld_pred,
psdf
)
else:
loss = torch.nn.functional.mse_loss(pred_sdf, gt_sdf)
# --- 4. check the result of the loss computation ---
if self.debug_mode:
if check_tensor(loss, "Calculated Loss", epoch, step):
logger.error(f"Epoch {epoch} Step {step}: Loss calculation resulted in inf/nan.")
if loss_details: logger.error(f"Loss Details: {loss_details}")
return float('inf') # stop this epoch if the loss is invalid
except Exception as loss_e:
logger.error(f"Epoch {epoch} Step {step}: Error during loss calculation: {loss_e}", exc_info=True)
return float('inf') # stop this epoch if the computation failed
logger.gpu_memory_stats("after loss computation")
# --- backward pass and optimization ---
try:
# backward pass
self.scheduler.optimizer.zero_grad() # clear gradients
loss.backward() # backpropagate
self.scheduler.optimizer.step() # update parameters
self.scheduler.step(loss,epoch)
except Exception as backward_e:
logger.error(f"Epoch {epoch} Step {step}: Error during backward pass or optimizer step: {backward_e}", exc_info=True)
# enable anomaly detection to see which operation caused it
# torch.autograd.set_detect_anomaly(True) # place before training starts
return float('inf') # stop this epoch if backward or the optimizer step failed
# --- record and accumulate the loss ---
current_loss = loss.item()
if not np.isfinite(current_loss): # double-check the loss is a finite number
logger.error(f"Epoch {epoch} Step {step}: Loss item is not finite ({current_loss}).")
return float('inf')
total_loss += current_loss
del loss
torch.cuda.empty_cache()
# log training progress (valid losses only)
logger.info(f'Train Epoch: {epoch:4d}]\t'
f'Loss: {current_loss:.6f}')
if loss_details: logger.info(f"Loss Details: {loss_details}")
return total_loss # for single-batch training, just return the current loss
def train_epoch(self, epoch: int,resample:bool=True) -> float:
# --- 1. check the input data ---
# note: self.sdf_data is assumed to contain [x, y, z, nx, ny, nz, sdf] (7 columns) or [x, y, z, sdf] (4 columns)
# with the SDF value always in the last column
@@ -447,16 +745,19 @@ class Trainer:
best_val_loss = float('inf')
logger.info("Starting training...")
start_time = time.time()
self.cached_train_data=None
start_epoch = 1
if args.resume_checkpoint_path:
start_epoch = self._load_checkpoint(args.resume_checkpoint_path)
logger.info(f"Loaded model from {args.resume_checkpoint_path}")
self.model.encoder.freeze_stage1()
for epoch in range(start_epoch, self.config.train.num_epochs + 1):
# train one epoch
#train_loss = self.train_epoch_stage1(epoch)
train_loss = self.train_epoch(epoch)
train_loss = self.train_epoch_stage1(epoch)
#train_loss = self.train_epoch_stage2(epoch)
#train_loss = self.train_epoch(epoch)
# validation
'''
@@ -476,7 +777,7 @@
if epoch % self.config.train.save_freq == 0:
self._save_checkpoint(epoch, train_loss)
logger.info(f'Checkpoint saved at epoch {epoch}')
self.model.encoder.unfreeze()
# training finished
total_time = time.time() - start_time
@@ -555,8 +856,6 @@ class Trainer:
logger.error(f"Failed to load checkpoint: {str(e)}")
raise
# ... existing code ...
def _save_octree(self):
"""
Save the octree to a file.
@@ -566,6 +865,7 @@
self.config.train.checkpoint_dir,
self.model_name
)
os.makedirs(checkpoint_dir, exist_ok=True)
octree_path = os.path.join(checkpoint_dir, "octree.pth")
try:
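
As committed, train() now runs train_epoch_stage1 for every epoch with freeze_stage1 applied up front and unfreeze called at the end. A condensed sketch of how the stage-1/stage-2 pieces could alternatively be sequenced; the epoch split and freeze placement below are assumptions, only the method names come from this commit:

def fit(trainer, stage1_epochs, total_epochs):
    # Hypothetical two-stage schedule; the epoch split is illustrative.
    trainer.cached_train_data = None
    # Stage 1: background-field epochs (forward_background path).
    for epoch in range(1, stage1_epochs + 1):
        trainer.train_epoch_stage1(epoch)
    # Stage 2: per-patch volumes with the background frozen (forward_without_octree path).
    trainer.model.encoder.freeze_stage2()
    for epoch in range(stage1_epochs + 1, total_epochs + 1):
        trainer.train_epoch_stage2(epoch)
    # Re-enable everything at the end, e.g. for export or a joint fine-tune.
    trainer.model.encoder.unfreeze()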

brep2sdf/utils/load.py (133)

@@ -0,0 +1,133 @@
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import logging
# assume logger is configured via the logging module
logger = logging.getLogger(__name__)
# utils
def get_namelist(path):
try:
with open(path, 'r') as f:
names = [line.strip() for line in f if line.strip()]
logger.info(f"Read {len(names)} names from '{path}'.")
return names
except FileNotFoundError:
logger.error(f"Error: file '{path}' not found.")
return
except Exception as e:
logger.error(f"Error while reading file '{path}': {e}")
return
def get_step_paths(names, step_root_dir, file_extensions, name_filter=None):
"""
Collect all matching STEP file paths for a list of names.
Args:
names (list): names to process, e.g. as returned by get_namelist
step_root_dir (str): root directory of the STEP files; each name has its own subdirectory
file_extensions (list): file extensions to match, e.g. ['.step', '.stp']
name_filter (callable, optional): filename filter taking the filename and the name, returning a bool
Returns:
list: list of matching STEP file paths
"""
# get the name list
if names is None:
logger.error("Could not obtain the name list; aborting the task.")
return []
step_file_paths = []
skipped_count = 0
# iterate over each name and look for matching STEP files
for name in names:
step_dir = os.path.join(step_root_dir, name)
if not os.path.isdir(step_dir):
logger.warning(f"Directory '{step_dir}' does not exist. Skipping '{name}'.")
skipped_count += 1
continue
step_files = []
try:
# find matching files
step_files = [
os.path.join(step_dir, f)
for f in os.listdir(step_dir)
if f.lower().endswith(tuple(file_extensions)) and (not name_filter or name_filter(f, name))
]
except OSError as e:
logger.warning(f"Cannot access directory '{step_dir}': {e}. Skipping '{name}'.")
skipped_count += 1
continue
if len(step_files) == 0:
logger.warning(f"No matching files found in directory '{step_dir}'. Skipping '{name}'.")
skipped_count += 1
elif len(step_files) > 1:
logger.warning(f"Multiple matching files found in directory '{step_dir}'; using the first one: {step_files[0]}")
step_file_paths.append(step_files[0])
else:
step_file_paths.append(step_files[0])
logger.info(f"Collected {len(step_file_paths)} file paths; skipped {skipped_count} names.")
return step_file_paths
def run_batch_task(task_function, args, common_args_func, file_extensions, name_filter=None):
"""
Generic batch task runner.
Args:
task_function: task to execute; receives the file path, the script path, and the common arguments
args: command-line argument object
common_args_func: function that builds the common arguments
file_extensions: list of file extensions to match
name_filter: optional filename filter function
Returns:
None
"""
# collect the task file paths
names = get_namelist(args.name_list_path)
tasks = get_step_paths(names, args.step_root_dir, file_extensions, name_filter)
if not tasks:
logger.info("No valid files found to process.")
return
# prepare the common arguments
common_args = common_args_func(args)
success_count = 0
failure_count = 0
skipped_count = len(names or []) - len(tasks)
# process in parallel with ProcessPoolExecutor
with ProcessPoolExecutor(max_workers=args.workers) as executor:
# submit all tasks
futures = {
executor.submit(task_function, task_path, args.train_script, common_args): task_path
for task_path in tasks
}
# show progress with tqdm and handle the results
for future in tqdm(as_completed(futures), total=len(tasks), desc="Running tasks"):
input_path = futures[future]
try:
input_file, success, stdout, stderr = future.result()
if success:
success_count += 1
# stdout/stderr of successful runs could be logged too, but usually only failures are worth logging
# logger.debug(f"Processed '{input_file}' successfully. STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
else:
failure_count += 1
logger.error(f"Failed to process '{input_file}'. STDOUT:\n{stdout}\nSTDERR:\n{stderr}")
except Exception as e:
failure_count += 1
logger.error(f"Failed to retrieve the result for task '{input_path}': {e}")
logger.info(f"Processing finished. Succeeded: {success_count}, failed: {failure_count}, skipped: {skipped_count}")
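
A hypothetical driver showing how run_batch_task could be wired up; the argparse field names (name_list_path, step_root_dir, train_script, workers) mirror the attributes the function reads above, and run_one is an assumed task function, not part of this commit:

import argparse
import subprocess
from brep2sdf.utils.load import run_batch_task

def run_one(step_path, script, extra_args):
    # Assumed task function: launch the training script on one STEP file
    # and return (path, ok, stdout, stderr), matching what run_batch_task unpacks.
    proc = subprocess.run(["python", script, "-i", step_path, *extra_args],
                          capture_output=True, text=True)
    return step_path, proc.returncode == 0, proc.stdout, proc.stderr

def common_args(args):
    # Flags taken from common_train_args in batch_train.py above.
    return ["--use-normal", "--only-zero-surface"]

if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--name-list-path", dest="name_list_path", required=True)
    p.add_argument("--step-root-dir", dest="step_root_dir", required=True)
    p.add_argument("--train-script", dest="train_script", required=True)
    p.add_argument("--workers", type=int, default=4)
    run_batch_task(run_one, p.parse_args(), common_args, file_extensions=[".step", ".stp"])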

data/scripts/pre_processing1.py (18)

@@ -5,6 +5,7 @@ import multiprocessing
from concurrent.futures import ProcessPoolExecutor, as_completed
import argparse
import time
from brep2sdf.utils.load import get_namelist,get_step_paths
from brep2sdf.utils.logger import logger
import numpy as np
@@ -335,12 +336,23 @@ def test_single_step(step_path, output_obj_path=None, linear_deflection=0.01):
print(f"\nProcessing failed: {str(e)}")
return None
def process_for_namelist():
names = get_namelist("/home/wch/brep2sdf/data/name_list.txt")
for name in names:
# use glob to collect the matching files
step_files = glob.glob(f"/home/wch/brep2sdf/data/step/{name}/{name}*.step")
output = f"/home/wch/brep2sdf/data/gt_mesh/{name}.obj"
test_single_step(step_files[0], output_obj_path=output)
# prepare tasks (executed in the main process)
if __name__ == "__main__":
main()
#main()
process_for_namelist()
'''
test_single_step(
"/home/wch/brep2sdf/data/step/00002736/00002736_82034c87704b46a891e498d6_step_004.step",
"/home/wch/brep2sdf/data/gt_mesh/00002736.obj"
"/home/wch/brep2sdf/data/step/00000010/00000010_b4b99d35e04b4277931f9a9c_step_000.step",
"/home/wch/brep2sdf/data/gt_mesh/00000031.obj"
)
'''
