- Added `data_loader.py` with `NHREP_Dataset` class for loading point cloud, feature mask, and CSG tree data
- Implemented `CustomDataLoader` for flexible data loading with configurable parameters
- Refactored `train.py` to create a structured training pipeline for NHRepNet
- Added support for feature sampling, device selection, and TensorBoard logging
- Introduced modular training methods with error handling and logging
2 changed files with 264 additions and 28 deletions
@@ -0,0 +1,192 @@
import os
import sys
import time

project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pyhocon import ConfigFactory
from typing import List, Tuple
from utils.logger import logger
from utils.general import load_point_cloud_by_file_extension, load_feature_mask
'''
Each model corresponds to three files:

*_50k.xyz: 50,000 sampled points of the input B-Rep; can be visualized with MeshLab.
    e.g. x,y,z,nx,ny,nz per line
*_50k_mask.txt: (patch_id + 1) of the sampled points.
    e.g. one integer per line
*_50k_csg.conf: Boolean tree built on the patches, stored in nested lists.
    'flag_convex' indicates the convexity of the root node.
    e.g.
    csg{
        list = [0,1,[2,3,4,],],
        flag_convex = 1,
    }
'''
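
# For reference, a minimal sketch of how the CSG example above parses with
# pyhocon (parse_string is used here purely for illustration):
#
#   from pyhocon import ConfigFactory
#   conf = ConfigFactory.parse_string('csg { list = [0,1,[2,3,4,],], flag_convex = 1, }')
#   conf.get_list('csg.list')        # [0, 1, [2, 3, 4]]
#   conf.get_int('csg.flag_convex')  # 1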


class NHREP_Dataset(Dataset):
    def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False):
        """
        Initialize the dataset.
        :param data_dir: data directory
        :param name_prefix: model name
        :param if_baseline: whether this is a baseline model (uses a trivial single-node CSG tree)
        :param if_feature_sample: whether to load feature samples
        """
        self.data_dir = os.path.abspath(data_dir)  # resolve the data directory to an absolute path
        self.if_baseline = if_baseline
        self.if_feature_sample = if_feature_sample
        self._load_single_data(self.data_dir, name_prefix, if_baseline, if_feature_sample)

    def _check_data_file_exists(self, file_name: str):
        if not os.path.exists(file_name):
            logger.error(f"Data file not found: {file_name}")
            raise FileNotFoundError(f"Data file not found: {file_name}")

    def _load_feature_samples(self, data_dir: str, file_prefix: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load feature samples."""
        try:
            logger.info(f"Loading feature samples for {file_prefix}")
            # load feature points
            input_fs_file = os.path.join(data_dir, file_prefix + '_feature.xyz')
            self._check_data_file_exists(input_fs_file)
            feature_data = torch.tensor(
                np.loadtxt(input_fs_file),
                dtype=torch.float32,
                device='cuda'
            )

            # load feature mask
            fs_mask_file = os.path.join(data_dir, file_prefix + '_feature_mask.txt')
            self._check_data_file_exists(fs_mask_file)
            feature_data_mask_pair = torch.tensor(
                np.loadtxt(fs_mask_file),
                dtype=torch.int64,
                device='cuda'
            )

            return feature_data, feature_data_mask_pair
        except Exception as e:
            logger.error(f"Error loading feature samples: {str(e)}")
            raise

    def _load_single_data(self, data_dir: str, name_prefix: str, if_baseline: bool, if_feature_sample: bool):
        """Load all data for a single model.
        :param data_dir: data directory
        :param name_prefix: model name
        :param if_baseline: whether this is a baseline model
        :param if_feature_sample: whether to load feature samples
        """
        try:
            logger.info(f"Loading data for {name_prefix}")
            # load xyz file
            # self.data: 2D float array; each row is a point in 3D space
            xyz_file = os.path.join(data_dir, name_prefix + '.xyz')
            self._check_data_file_exists(xyz_file)
            self.data = load_point_cloud_by_file_extension(xyz_file)

            # load mask file
            # self.feature_mask: 1D integer array; each entry is (patch_id + 1) of the corresponding point
            mask_file = os.path.join(data_dir, name_prefix + '_mask.txt')
            self._check_data_file_exists(mask_file)
            self.feature_mask = load_feature_mask(mask_file)

            # load csg file
            # self.csg_tree: nested lists; each inner list is a node of the CSG tree
            # self.csg_flag_convex: whether the root node is convex
            try:
                if if_baseline:
                    self.csg_tree = [0]
                    self.csg_flag_convex = True
                else:
                    csg_conf_file = os.path.join(data_dir, name_prefix + '_csg.conf')
                    self._check_data_file_exists(csg_conf_file)
                    csg_config = ConfigFactory.parse_file(csg_conf_file)
                    self.csg_tree = csg_config.get_list('csg.list')
                    self.csg_flag_convex = csg_config.get_int('csg.flag_convex') == 1
            except Exception as e:
                logger.error(f"Error in CSG tree setup: {str(e)}")
                raise

            # load feature samples
            # self.feature_data: 2D float array; each row is a feature point
            # self.feature_data_mask_pair: 1D integer array; mask value of each feature point
            if if_feature_sample:
                self.feature_data, self.feature_data_mask_pair = self._load_feature_samples(data_dir, name_prefix)
        except Exception as e:
            logger.error(f"Error loading data: {str(e)}")
            raise

    def get_data(self):
        return self.data

    def get_feature_mask(self):
        return self.feature_mask

    def get_csg_tree(self):
        return self.csg_tree, self.csg_flag_convex

    def get_feature_data(self):
        if self.if_feature_sample:
            return self.feature_data, self.feature_data_mask_pair
        else:
            return None, None
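
    # Minimal __len__/__getitem__ sketch (an assumed addition, not spelled out
    # in this commit) so the dataset can be consumed by a torch DataLoader:
    # each item is the (point, mask) pair for one sampled point.
    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.feature_mask[idx]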


class CustomDataLoader:
    def __init__(self, data_dir, name_prefix, batch_size=32, shuffle=True, num_workers=4, transform=None):
        """
        Initialize the data loader.
        :param data_dir: data directory
        :param name_prefix: model name
        :param batch_size: batch size
        :param shuffle: whether to shuffle the data
        :param num_workers: number of worker subprocesses
        :param transform: data augmentation or transform (currently not applied by NHREP_Dataset)
        """
        self.dataset = NHREP_Dataset(data_dir, name_prefix)
        self.transform = transform
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers
        )

    def get_loader(self):
        """Return the underlying DataLoader."""
        return self.dataloader


# Example usage
if __name__ == "__main__":
    # data directory and model name prefix
    data_dir = '../data/input_data'
    name_prefix = 'broken_bullet_50k'

    # example transform (illustration only; NHREP_Dataset does not apply it)
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.5], std=[0.5]),  # normalization
    ])

    # create the dataset instance
    dataset = NHREP_Dataset(data_dir, name_prefix)
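
    # A short sketch of pushing the dataset through CustomDataLoader (assumes
    # the __len__/__getitem__ sketch above; num_workers=0 keeps it single-process):
    loader = CustomDataLoader(data_dir, name_prefix, batch_size=32, num_workers=0).get_loader()
    points, masks = next(iter(loader))
    print(points.shape, masks.shape)  # e.g. torch.Size([32, 6]) torch.Size([32])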
@@ -1,41 +1,85 @@
import os
import sys
import time

project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.append(project_dir)
os.chdir(project_dir)

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from utils.logger import logger
from utils.general import gradient

from data_loader import NHREP_Dataset
from model.network import NHRepNet  # NHRepNet architecture

class NHREPNet_Training:
    def __init__(self, data_dir, name_prefix: str, if_baseline: bool = False, if_feature_sample: bool = False):
        self.dataset = NHREP_Dataset(data_dir, name_prefix, if_baseline, if_feature_sample)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # initialize the model
        self.d_in = 6  # input dimension: xyz coordinates plus normals
        dims_sdf = [64, 64, 64]  # hidden-layer dimensions
        csg_tree, _ = self.dataset.get_csg_tree()
        self.model = NHRepNet(self.d_in, dims_sdf, csg_tree).to(self.device)  # instantiate the model on the target device

        self.points_batch = 16384  # points sampled per iteration (assumed default; never set in the original)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)  # Adam optimizer
        self.scheduler = StepLR(self.optimizer, step_size=1000, gamma=0.1)  # learning-rate scheduler
        self.nepochs = 15000  # number of training epochs
        self.writer = SummaryWriter()  # TensorBoard writer

    def run_nhrepnet_training(self):
        logger.info("Starting training")
        self.data = self.dataset.get_data().to(self.device)  # move the point data to the device
        self.data.requires_grad_()  # enable gradients with respect to the input points
        feature_mask = self.dataset.get_feature_mask().to(self.device)
        self.n_branch = int(torch.max(feature_mask).item())  # number of branches (patches)
        n_batchsize = self.points_batch  # total points per batch
        self.n_patch_batch = n_batchsize // self.n_branch  # points sampled per branch

        # per-branch point indices: mask value i + 1 marks the points of branch i
        self.patch_id = [np.where(feature_mask.cpu().numpy() == i + 1)[0] for i in range(self.n_branch)]

        self.model.train()  # set the model to training mode
        for epoch in tqdm(range(self.nepochs)):  # training loop
            try:
                self.train_one_epoch(epoch)
                self.scheduler.step()  # update the learning rate
            except Exception as e:
                logger.error(f"Error during training: {str(e)}")
                break

    def train_one_epoch(self, epoch):
        # draw the same number of points from every branch to form a balanced batch
        indices = torch.cat([
            torch.from_numpy(
                np.random.choice(self.patch_id[i], self.n_patch_batch, replace=True)
            ).to(self.device)
            for i in range(self.n_branch)
        ])
        cur_data = self.data[indices]  # gather the sampled rows
        mnfld_pnts = cur_data[:, :self.d_in]  # manifold (on-surface) points

        # forward pass: predicted implicit values at the manifold points
        mnfld_pred_all = self.model(mnfld_pnts)
        mnfld_pred = mnfld_pred_all[:, 0]  # overall (blended) prediction

        loss = self.compute_loss(mnfld_pred)

        # backward pass
        self.optimizer.zero_grad()  # clear accumulated gradients
        loss.backward()  # backpropagate
        self.optimizer.step()  # update the parameters

        if epoch % 100 == 0:  # log every 100 epochs
            logger.info(f'Epoch [{epoch}/{self.nepochs}], Loss: {loss.item():.4f}')
        self.writer.add_scalar('Loss/train', loss.item(), epoch)  # record the loss in TensorBoard

    def compute_loss(self, mnfld_pred):
        """
        Manifold loss: points sampled on the surface should evaluate to zero,
        so the mean absolute predicted value is penalized.
        :param mnfld_pred: predicted implicit values at the manifold points
        :return: the manifold loss
        """
        manifold_loss = mnfld_pred.abs().mean()
        return manifold_loss
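
    # Note: `gradient` from utils.general is imported but unused above. A fuller
    # objective would typically add terms computed from the gradient of the
    # prediction; a hedged sketch (an assumption, not part of this commit):
    #
    #   mnfld_grad = gradient(mnfld_pnts, mnfld_pred)               # d(pred)/d(point)
    #   normals = cur_data[:, -3:]                                  # nx, ny, nz columns
    #   normal_term = (mnfld_grad - normals).norm(2, dim=1).mean()  # align gradients with normals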

if __name__ == "__main__":
    data_dir = '../data/input_data'  # data directory
    name_prefix = 'broken_bullet_50k'
    train = NHREPNet_Training(data_dir, name_prefix, if_baseline=True, if_feature_sample=False)
    train.run_nhrepnet_training()
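
# To monitor training, point TensorBoard at SummaryWriter's default log
# directory (an operational note, not part of the commit):
#   tensorboard --logdir runs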