From 9a25fe7e6d84978e1ac95d08fc0d24def99cfcf6 Mon Sep 17 00:00:00 2001
From: mckay
Date: Sat, 16 Nov 2024 10:22:24 +0800
Subject: [PATCH] WIP: stash current changes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README..md             |   44 --
 brep2sdf/data/data.py  |   20 +-
 brep2sdf/data/test.py  |  121 +++++
 brep2sdf/data/utils.py | 1100 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 1234 insertions(+), 51 deletions(-)
 delete mode 100644 README..md
 create mode 100644 brep2sdf/data/test.py
 create mode 100644 brep2sdf/data/utils.py

diff --git a/README..md b/README..md
deleted file mode 100644
index cfa418a..0000000
--- a/README..md
+++ /dev/null
@@ -1,44 +0,0 @@
-## What this does
-Learn an SDF representation of B-rep models
-
-## Why it matters
-- B-rep is easy to control
-- SDF is convenient for computation and learned generation
-
-## How
-### Data preparation
-Data source: furniture
-File structure
-    brep( .step )
-        train
-            chair
-                chair_1.brep
-                chair_2.brep
-        test
-        val
-    sdf ( .npz )
-        train
-            chair
-                chair_1.sdf.npz
-                chair_2.sdf.npz
-
-Data structure
-
-
-### brep to mesh
-Basic idea: sample, then triangulate
-
-- Resampling
-    "Mesh Generation from CAD Models": sharp edges
-
-
-### mesh to sdf
-Basic idea: voxelize, then integrate
-
-
-
-
-docker
-```
-docker run -it -v /mnt/disk2/dataset:/data -v $(pwd):/app --name brep_sdf ubuntu:22.04
-```
diff --git a/brep2sdf/data/data.py b/brep2sdf/data/data.py
index 51ccbfd..51f2782 100644
--- a/brep2sdf/data/data.py
+++ b/brep2sdf/data/data.py
@@ -4,10 +4,14 @@ from torch.utils.data import Dataset
 import numpy as np
 import pickle
 from brep2sdf.utils.logger import setup_logger
-
+from .utils import process_brep_data
 # Set up the logger
 logger = setup_logger('dataset')
 
+
+
+
 class BRepSDFDataset(Dataset):
     def __init__(self, brep_dir:str, sdf_dir:str, split:str='train'):
         """
@@ -82,19 +86,20 @@ class BRepSDFDataset(Dataset):
             logger.error(f"  {key}: type={type(value)}, dtype={value.dtype}, shape={value.shape}")
             raise
 
+
     def _load_brep_file(self, brep_path):
         """Load a B-rep feature file"""
         try:
             with open(brep_path, 'rb') as f:
                 brep_data = pickle.load(f)
 
-            features = {}
+            # convert the loaded arrays to tensors in place on brep_data
             # 1. Geometry data (variable-length sequences)
             for key in ['surf_wcs', 'surf_ncs', 'edge_wcs', 'edge_ncs']:
                 if key in brep_data:
                     try:
-                        features[key] = [
+                        brep_data[key] = [
                             torch.from_numpy(np.array(x, dtype=np.float32))
                             for x in brep_data[key]
                         ]
@@ -114,7 +119,7 @@ class BRepSDFDataset(Dataset):
             if key in brep_data:
                 try:
                     data = np.array(brep_data[key], dtype=np.float32)
-                    features[key] = torch.from_numpy(data)
+                    brep_data[key] = torch.from_numpy(data)
                 except Exception as e:
                     logger.error(f"Error converting {key}:")
                     logger.error(f"  Type: {type(brep_data[key])}")
@@ -128,7 +133,7 @@ class BRepSDFDataset(Dataset):
             if key in brep_data:
                 try:
                     data = np.array(brep_data[key], dtype=np.int32)
-                    features[key] = torch.from_numpy(data)
+                    brep_data[key] = torch.from_numpy(data)
                 except Exception as e:
                     logger.error(f"Error converting {key}:")
                     logger.error(f"  Type: {type(brep_data[key])}")
@@ -136,8 +141,9 @@ class BRepSDFDataset(Dataset):
                     logger.error(f"  Shape: {brep_data[key].shape}")
                     logger.error(f"  dtype: {brep_data[key].dtype}")
                     raise ValueError(f"Failed to convert {key}: {str(e)}")
-
-            return features
+            features = process_brep_data(brep_data, max_face=70, max_edge=70, bbox_scaled=1.0)
+
+            return features
 
         except Exception as e:
             logger.error(f"\nError loading B-rep file: {brep_path}")
diff --git a/brep2sdf/data/test.py b/brep2sdf/data/test.py
new file mode 100644
index 0000000..a146d68
--- /dev/null
+++ b/brep2sdf/data/test.py
@@ -0,0 +1,121 @@
+import os
+import pickle
+import numpy as np
+import logging
+from datetime import datetime
+
+# Create the logs directory
+log_dir = 'logs'
+os.makedirs(log_dir, exist_ok=True)
+
+# Create the logger
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+# Build the log file name (timestamped)
+timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+log_file = os.path.join(log_dir, f'brep_log_{timestamp}.log')
+
+# Create the file handler
+file_handler = logging.FileHandler(log_file, encoding='utf-8')
+file_handler.setLevel(logging.INFO)
+
+# Create the console handler
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)
+
+# Create the formatter
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+file_handler.setFormatter(formatter)
+console_handler.setFormatter(formatter)
+
+# Attach the handlers to the logger
+logger.addHandler(file_handler)
+logger.addHandler(console_handler)
+
+# Prevent duplicate log records
+logger.propagate = False
+
+def _load_brep_file(brep_path) -> dict:
+    """Load a B-rep file and log its data structure"""
+    try:
+        # Load the B-rep features
+        with open(brep_path, 'rb') as f:
+            brep_data = pickle.load(f)
+
+        # Log the data structure
+        logger.info(f"\nData structure from {os.path.basename(brep_path)}:")
+        logger.info(brep_data.keys())
+
+        # 1. Geometry data
+        if 'surf_wcs' in brep_data:
+            logger.info(f"  surf_wcs: list of {len(brep_data['surf_wcs'])} arrays")
+            for i, surf in enumerate(brep_data['surf_wcs']):
+                logger.info(f"    surface {i}: shape={np.array(surf).shape}, type={type(surf)}")
+
+        if 'surf_ncs' in brep_data:
+            logger.info(f"  surf_ncs: list of {len(brep_data['surf_ncs'])} arrays")
+            for i, surf in enumerate(brep_data['surf_ncs']):
+                logger.info(f"    surface {i}: shape={np.array(surf).shape}, type={type(surf)}")
+
+        if 'edge_wcs' in brep_data:
+            logger.info(f"  edge_wcs: list of {len(brep_data['edge_wcs'])} arrays")
+            for i, edge in enumerate(brep_data['edge_wcs']):
+                logger.info(f"    edge {i}: shape={np.array(edge).shape}, type={type(edge)}")
+
+        if 'edge_ncs' in brep_data:
+            logger.info(f"  edge_ncs: list of {len(brep_data['edge_ncs'])} arrays")
+            for i, edge in enumerate(brep_data['edge_ncs']):
+                logger.info(f"    edge {i}: shape={np.array(edge).shape}, type={type(edge)}")
+
+        # 2. Vertex data
+        if 'corner_wcs' in brep_data:
+            logger.info(f"  corner_wcs: shape={brep_data['corner_wcs'].shape}, type={type(brep_data['corner_wcs'])}")
+
+        if 'corner_unique' in brep_data:
+            logger.info(f"  corner_unique: shape={brep_data['corner_unique'].shape}, type={type(brep_data['corner_unique'])}")
+
+        # 3. Bounding box data
+        if 'surf_bbox_wcs' in brep_data:
+            logger.info(f"  surf_bbox_wcs: shape={brep_data['surf_bbox_wcs'].shape}, type={type(brep_data['surf_bbox_wcs'])}")
+
+        if 'edge_bbox_wcs' in brep_data:
+            logger.info(f"  edge_bbox_wcs: shape={brep_data['edge_bbox_wcs'].shape}, type={type(brep_data['edge_bbox_wcs'])}")
+
+        # 4. Adjacency data
+        if 'edgeFace_adj' in brep_data:
+            logger.info(f"  edgeFace_adj: shape={brep_data['edgeFace_adj'].shape}, type={type(brep_data['edgeFace_adj'])}")
+
+        if 'edgeCorner_adj' in brep_data:
+            logger.info(f"  edgeCorner_adj: shape={brep_data['edgeCorner_adj'].shape}, type={type(brep_data['edgeCorner_adj'])}")
+
+        if 'faceEdge_adj' in brep_data:
+            logger.info(f"  faceEdge_adj: shape={brep_data['faceEdge_adj'].shape}, type={type(brep_data['faceEdge_adj'])}")
+
+        return brep_data
+
+    except Exception as e:
+        logger.error(f"Error loading brep from {brep_path}: {str(e)}")
+        raise
+
+
+def test_load_brep():
+    """Test loading a B-rep file"""
+    try:
+        # Test file path (replace with your actual test file path)
+        test_file = "/home/wch/brep2sdf/test_data/pkl/train/chair_1105.pkl"
+
+        # Exercise the loader
+        loaded_data = _load_brep_file(test_file)
+        # Optionally sanity-check the shapes (see _quick_shape_check below)
+
+        logger.info("Test completed successfully!")
+
+    except Exception as e:
+        logger.error(f"Test failed: {str(e)}")
+        raise
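+
+def _quick_shape_check(brep_data: dict) -> None:
+    """A minimal consistency sketch over the keys logged above.
+
+    The exact shapes asserted here are assumptions about this dataset's layout
+    (mirroring the docstring of process_brep_data in utils.py), not guarantees.
+    """
+    num_faces = len(brep_data['surf_ncs'])
+    num_edges = len(brep_data['edge_ncs'])
+    # each edge should record its two incident corners
+    assert brep_data['corner_wcs'].shape == (num_edges, 2, 3)
+    # the face-edge adjacency should cover every face and edge
+    assert brep_data['faceEdge_adj'].shape == (num_faces, num_edges)
+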
+
+if __name__ == '__main__':
+    test_load_brep()
\ No newline at end of file
diff --git a/brep2sdf/data/utils.py b/brep2sdf/data/utils.py
new file mode 100644
index 0000000..d9106a0
--- /dev/null
+++ b/brep2sdf/data/utils.py
@@ -0,0 +1,1100 @@
+import numpy as np
+import math
+import torch
+import torch.nn as nn
+import random
+import string
+import argparse
+from chamferdist import ChamferDistance
+from mpl_toolkits.mplot3d.art3d import Poly3DCollection
+from typing import List, Optional, Tuple, Union
+
+from OCC.Core.gp import gp_Pnt
+from OCC.Core.TColgp import TColgp_Array1OfPnt, TColgp_Array2OfPnt
+from OCC.Core.GeomAPI import GeomAPI_PointsToBSplineSurface, GeomAPI_PointsToBSpline
+from OCC.Core.GeomAbs import GeomAbs_C2
+from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_MakeWire, BRepBuilderAPI_MakeFace, BRepBuilderAPI_MakeEdge, BRepBuilderAPI_Sewing, BRepBuilderAPI_MakeSolid
+from OCC.Extend.TopologyUtils import TopologyExplorer, WireExplorer
+from OCC.Core.ShapeFix import ShapeFix_Face, ShapeFix_Wire, ShapeFix_Edge
+from OCC.Core.ShapeAnalysis import ShapeAnalysis_Wire
+
+
+def generate_random_string(length):
+    characters = string.ascii_letters + string.digits  # You can include other characters if needed
+    random_string = ''.join(random.choice(characters) for _ in range(length))
+    return random_string
+
+
+def get_bbox_norm(point_cloud):
+    # Find the minimum and maximum coordinates along each axis
+    min_x = np.min(point_cloud[:, 0])
+    max_x = np.max(point_cloud[:, 0])
+
+    min_y = np.min(point_cloud[:, 1])
+    max_y = np.max(point_cloud[:, 1])
+
+    min_z = np.min(point_cloud[:, 2])
+    max_z = np.max(point_cloud[:, 2])
+
+    # Create the 3D bounding box using the min and max values
+    min_point = np.array([min_x, min_y, min_z])
+    max_point = np.array([max_x, max_y, max_z])
+    return np.linalg.norm(max_point - min_point)
+
+
+def compute_bbox_center_and_size(min_corner, max_corner):
+    # Calculate the center
+    center_x = (min_corner[0] + max_corner[0]) / 2
+    center_y = (min_corner[1] + max_corner[1]) / 2
+    center_z = (min_corner[2] + max_corner[2]) / 2
+    center = np.array([center_x, center_y, center_z])
+    # Calculate the size
+    size_x = max_corner[0] - min_corner[0]
+    size_y = max_corner[1] - min_corner[1]
+    size_z = max_corner[2] - min_corner[2]
+    size = max(size_x, size_y, size_z)
+    return center, size
+
+
+def randn_tensor(
+    shape: Union[Tuple, List],
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional["torch.device"] = None,
+    dtype: Optional["torch.dtype"] = None,
+    layout: Optional["torch.layout"] = None,
+):
+    """A helper function for creating random tensors on the desired `device` with the desired `dtype`. When
+    passing a list of generators, each batch element can be seeded individually. If CPU generators are passed,
+    the tensor is created on CPU and then moved to `device`.
+    """
+    # the device on which the tensor is created; defaults to `device`
+    rand_device = device
+    batch_size = shape[0]
+
+    layout = layout or torch.strided
+    device = device or torch.device("cpu")
+
+    if generator is not None:
+        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
+        if gen_device_type != device.type and gen_device_type == "cpu":
+            rand_device = "cpu"
+        elif gen_device_type != device.type and gen_device_type == "cuda":
+            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+    if isinstance(generator, list):
+        shape = (1,) + shape[1:]
+        latents = [
+            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
+            for i in range(batch_size)
+        ]
+        latents = torch.cat(latents, dim=0).to(device)
+    else:
+        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
+
+    return latents
+
+
+def pad_repeat(x, max_len):
+    """Pad x to max_len rows by repeating its rows."""
+    repeat_times = math.floor(max_len/len(x))
+    sep = max_len-repeat_times*len(x)
+    sep1 = np.repeat(x[:sep], repeat_times+1, axis=0)
+    sep2 = np.repeat(x[sep:], repeat_times, axis=0)
+    x_repeat = np.concatenate([sep1, sep2], 0)
+    return x_repeat
+
+
+def pad_zero(x, max_len, return_mask=False):
+    """Zero-pad x along axis 0 to max_len; the optional mask is True on padded rows."""
+    keys = np.ones(len(x))
+    padding = np.zeros((max_len-len(x))).astype(int)
+    mask = 1-np.concatenate([keys, padding]) == 1
+    padding = np.zeros((max_len-len(x), *x.shape[1:]))
+    x_padded = np.concatenate([x, padding], axis=0)
+    if return_mask:
+        return x_padded, mask
+    else:
+        return x_padded
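+
+
+# Worked example (a sketch; the shapes are illustrative):
+#   x = np.ones((3, 2))
+#   x_padded, mask = pad_zero(x, 5, return_mask=True)
+#   x_padded.shape -> (5, 2), with two all-zero rows appended
+#   mask.tolist()  -> [False, False, False, True, True]   (True marks padding)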
+
+
+def plot_3d_bbox(ax, min_corner, max_corner, color='r'):
+    """
+    Helper function for plotting 3D bounding boxes
+    """
+    vertices = [
+        (min_corner[0], min_corner[1], min_corner[2]),
+        (max_corner[0], min_corner[1], min_corner[2]),
+        (max_corner[0], max_corner[1], min_corner[2]),
+        (min_corner[0], max_corner[1], min_corner[2]),
+        (min_corner[0], min_corner[1], max_corner[2]),
+        (max_corner[0], min_corner[1], max_corner[2]),
+        (max_corner[0], max_corner[1], max_corner[2]),
+        (min_corner[0], max_corner[1], max_corner[2])
+    ]
+    # Define the six quadrilateral faces of the box
+    faces = [
+        [vertices[0], vertices[1], vertices[2], vertices[3]],
+        [vertices[4], vertices[5], vertices[6], vertices[7]],
+        [vertices[0], vertices[1], vertices[5], vertices[4]],
+        [vertices[2], vertices[3], vertices[7], vertices[6]],
+        [vertices[1], vertices[2], vertices[6], vertices[5]],
+        [vertices[4], vertices[7], vertices[3], vertices[0]]
+    ]
+    ax.add_collection3d(Poly3DCollection(faces, facecolors='blue', linewidths=1, edgecolors=color, alpha=0))
+    return
+
+
+def get_args_vae():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data', type=str, default='data_process/deepcad_parsed',
+                        help='Path to data folder')
+    parser.add_argument('--train_list', type=str, default='data_process/deepcad_data_split_6bit_surface.pkl',
+                        help='Path to training list')
+    parser.add_argument('--val_list', type=str, default='data_process/deepcad_data_split_6bit.pkl',
+                        help='Path to validation list')
+    # Training parameters
+    parser.add_argument("--option", type=str, choices=['surface', 'edge'], default='surface',
+                        help="Choose between option surface or edge (default: surface)")
+    parser.add_argument('--batch_size', type=int, default=512, help='input batch size')
+    parser.add_argument('--train_nepoch', type=int, default=200, help='number of epochs to train for')
+    parser.add_argument('--save_nepoch', type=int, default=20, help='save the model every this many epochs')
+    parser.add_argument('--test_nepoch', type=int, default=10, help='test the model every this many epochs')
+    parser.add_argument("--data_aug", action='store_true', help='Use data augmentation')
+    parser.add_argument("--finetune", action='store_true', help='Finetune from existing weights')
+    parser.add_argument("--weight", type=str, default=None, help='Weight path when finetuning')
+    parser.add_argument("--gpu", type=int, nargs='+', default=[0], help="GPU IDs to use for training (default: [0])")
+    # Save dirs and reload
+    parser.add_argument('--env', type=str, default="surface_vae", help='environment')
+    parser.add_argument('--dir_name', type=str, default="proj_log", help='name of the log folder.')
+    args = parser.parse_args()
+    # saved folder
+    args.save_dir = f'{args.dir_name}/{args.env}'
+    return args
+
+
+def get_args_ldm():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data', type=str, default='data_process/deepcad_parsed',
+                        help='Path to data folder')
+    parser.add_argument('--list', type=str, default='data_process/deepcad_data_split_6bit.pkl',
+                        help='Path to data list')
+    parser.add_argument('--surfvae', type=str, default='proj_log/deepcad_surfvae/epoch_400.pt',
+                        help='Path to pretrained surface vae weights')
+    parser.add_argument('--edgevae', type=str, default='proj_log/deepcad_edgevae/epoch_300.pt',
+                        help='Path to pretrained edge vae weights')
+    parser.add_argument("--option", type=str, choices=['surfpos', 'surfz', 'edgepos', 'edgez'], default='surfpos',
+                        help="Choose between option [surfpos,edgepos,surfz,edgez] (default: surfpos)")
+    # Training parameters
+    parser.add_argument('--batch_size', type=int, default=512, help='input batch size')
+    parser.add_argument('--train_nepoch', type=int, default=3000, help='number of epochs to train for')
+    parser.add_argument('--test_nepoch', type=int, default=25, help='test the model every this many epochs')
+    parser.add_argument('--save_nepoch', type=int, default=50, help='save the model every this many epochs')
+    parser.add_argument('--max_face', type=int, default=50, help='maximum number of faces')
+    parser.add_argument('--max_edge', type=int, default=30, help='maximum number of edges per face')
+    parser.add_argument('--threshold', type=float, default=0.05, help='minimum threshold between two faces')
+    parser.add_argument('--bbox_scaled', type=float, default=3, help='scaling factor applied to the bbox')
+    parser.add_argument('--z_scaled', type=float, default=1, help='scaling factor applied to the latent z')
+    parser.add_argument("--gpu", type=int, nargs='+', default=[0, 1], help="GPU IDs to use for training (default: [0, 1])")
+    parser.add_argument("--data_aug", action='store_true', help='Use data augmentation')
+    parser.add_argument("--cf", action='store_true', help='Use classifier-free guidance (class-conditional training)')
+    # Save dirs and reload
+    parser.add_argument('--env', type=str, default="surface_pos", help='environment')
+    parser.add_argument('--dir_name', type=str, default="proj_log", help='name of the log folder.')
+    args = parser.parse_args()
+    # saved folder
+    args.save_dir = f'{args.dir_name}/{args.env}'
+    return args
+
+
+def rotate_point_cloud(point_cloud, angle_degrees, axis):
+    """
+    Rotate a point cloud around its center by a specified angle in degrees along a specified axis,
+    then rescale it to fit within the [-1, 1] cube.
+
+    Args:
+    - point_cloud: Numpy array of shape (N, 3) representing the point cloud.
+    - angle_degrees: Angle of rotation in degrees.
+    - axis: Axis of rotation. Can be 'x', 'y', or 'z'.
+
+    Returns:
+    - normalized_point_cloud: Numpy array of shape (N, 3), the rotated point cloud rescaled to the [-1, 1] cube.
+    """
+
+    # Convert angle to radians
+    angle_radians = np.radians(angle_degrees)
+
+    # Compute rotation matrix based on the specified axis
+    if axis == 'x':
+        rotation_matrix = np.array([[1, 0, 0],
+                                    [0, np.cos(angle_radians), -np.sin(angle_radians)],
+                                    [0, np.sin(angle_radians), np.cos(angle_radians)]])
+    elif axis == 'y':
+        rotation_matrix = np.array([[np.cos(angle_radians), 0, np.sin(angle_radians)],
+                                    [0, 1, 0],
+                                    [-np.sin(angle_radians), 0, np.cos(angle_radians)]])
+    elif axis == 'z':
+        rotation_matrix = np.array([[np.cos(angle_radians), -np.sin(angle_radians), 0],
+                                    [np.sin(angle_radians), np.cos(angle_radians), 0],
+                                    [0, 0, 1]])
+    else:
+        raise ValueError("Invalid axis. Must be 'x', 'y', or 'z'.")
+
+    # Center the point cloud
+    center = np.mean(point_cloud, axis=0)
+    centered_point_cloud = point_cloud - center
+
+    # Apply rotation
+    rotated_point_cloud = np.dot(centered_point_cloud, rotation_matrix.T)
+
+    # Translate back to original position
+    rotated_point_cloud += center
+
+    # Find the maximum absolute coordinate value
+    max_abs_coord = np.max(np.abs(rotated_point_cloud))
+
+    # Scale the point cloud to fit within the -1 to 1 cube
+    normalized_point_cloud = rotated_point_cloud / max_abs_coord
+
+    return normalized_point_cloud
+
+
+def get_bbox(pnts):
+    """
+    Get the tightest-fitting 3D (axis-aligned) bounding box for each given set of points
+    """
+    bbox_corners = []
+    for point_cloud in pnts:
+        # Find the minimum and maximum coordinates along each axis
+        min_x = np.min(point_cloud[:, 0])
+        max_x = np.max(point_cloud[:, 0])
+
+        min_y = np.min(point_cloud[:, 1])
+        max_y = np.max(point_cloud[:, 1])
+
+        min_z = np.min(point_cloud[:, 2])
+        max_z = np.max(point_cloud[:, 2])
+
+        # Create the 3D bounding box using the min and max values
+        min_point = np.array([min_x, min_y, min_z])
+        max_point = np.array([max_x, max_y, max_z])
+        bbox_corners.append([min_point, max_point])
+    return np.array(bbox_corners)
+
+
+def bbox_corners(bboxes):
+    """
+    Given the bottom-left and top-right corners of each bbox,
+    return all eight corners
+    """
+    bboxes_all_corners = []
+    for bbox in bboxes:
+        bottom_left, top_right = bbox[:3], bbox[3:]
+        # Bottom 4 corners
+        bottom_front_left = bottom_left
+        bottom_front_right = (top_right[0], bottom_left[1], bottom_left[2])
+        bottom_back_left = (bottom_left[0], top_right[1], bottom_left[2])
+        bottom_back_right = (top_right[0], top_right[1], bottom_left[2])
+
+        # Top 4 corners
+        top_front_left = (bottom_left[0], bottom_left[1], top_right[2])
+        top_front_right = (top_right[0], bottom_left[1], top_right[2])
+        top_back_left = (bottom_left[0], top_right[1], top_right[2])
+        top_back_right = top_right
+
+        # Combine all coordinates
+        all_corners = [
+            bottom_front_left,
+            bottom_front_right,
+            bottom_back_left,
+            bottom_back_right,
+            top_front_left,
+            top_front_right,
+            top_back_left,
+            top_back_right,
+        ]
+        bboxes_all_corners.append(np.vstack(all_corners))
+    bboxes_all_corners = np.array(bboxes_all_corners)
+    return bboxes_all_corners
+
+
+def rotate_axis(pnts, angle_degrees, axis, normalized=False):
+    """
+    Rotate a point cloud around its center by a specified angle in degrees along a specified axis.
+
+    Args:
+    - pnts: Numpy array of shape (N, ..., 3) representing the point cloud.
+    - angle_degrees: Angle of rotation in degrees.
+    - axis: Axis of rotation. Can be 'x', 'y', or 'z'.
+    - normalized: If True, rescale the result to fit within the [-1, 1] cube.
+
+    Returns:
+    - rotated_pnts: Numpy array of the same shape as pnts, rotated (and optionally rescaled).
+    """
+
+    # Convert angle to radians
+    angle_radians = np.radians(angle_degrees)
+
+    # Convert points to homogeneous coordinates
+    shape = list(np.shape(pnts))
+    shape[-1] = 1
+    pnts_homogeneous = np.concatenate((pnts, np.ones(shape)), axis=-1)
+
+    # Compute rotation matrix based on the specified axis
+    if axis == 'x':
+        rotation_matrix = np.array([
+            [1, 0, 0, 0],
+            [0, np.cos(angle_radians), -np.sin(angle_radians), 0],
+            [0, np.sin(angle_radians), np.cos(angle_radians), 0],
+            [0, 0, 0, 1]
+        ])
+    elif axis == 'y':
+        rotation_matrix = np.array([
+            [np.cos(angle_radians), 0, np.sin(angle_radians), 0],
+            [0, 1, 0, 0],
+            [-np.sin(angle_radians), 0, np.cos(angle_radians), 0],
+            [0, 0, 0, 1]
+        ])
+    elif axis == 'z':
+        rotation_matrix = np.array([
+            [np.cos(angle_radians), -np.sin(angle_radians), 0, 0],
+            [np.sin(angle_radians), np.cos(angle_radians), 0, 0],
+            [0, 0, 1, 0],
+            [0, 0, 0, 1]
+        ])
+    else:
+        raise ValueError("Invalid axis. Must be 'x', 'y', or 'z'.")
+
+    # Apply rotation
+    rotated_pnts_homogeneous = np.dot(pnts_homogeneous, rotation_matrix.T)
+    rotated_pnts = rotated_pnts_homogeneous[...,:3]
+
+    # Scale the point cloud to fit within the -1 to 1 cube
+    if normalized:
+        max_abs_coord = np.max(np.abs(rotated_pnts))
+        rotated_pnts = rotated_pnts / max_abs_coord
+
+    return rotated_pnts
+
+
+def rescale_bbox(bboxes, scale):
+    # Apply scaling factors to bounding boxes
+    scaled_bboxes = bboxes*scale
+    return scaled_bboxes
+
+
+def translate_bbox(bboxes):
+    """
+    Randomly move the object within the cube (x, y, z directions)
+    """
+    point_cloud = bboxes.reshape(-1,3)
+    min_x = np.min(point_cloud[:, 0])
+    max_x = np.max(point_cloud[:, 0])
+    min_y = np.min(point_cloud[:, 1])
+    max_y = np.max(point_cloud[:, 1])
+    min_z = np.min(point_cloud[:, 2])
+    max_z = np.max(point_cloud[:, 2])
+    # sample offsets that keep the geometry inside the [-1, 1] cube
+    x_offset = np.random.uniform(min(-1-min_x, 0), max(1-max_x, 0))
+    y_offset = np.random.uniform(min(-1-min_y, 0), max(1-max_y, 0))
+    z_offset = np.random.uniform(min(-1-min_z, 0), max(1-max_z, 0))
+    random_translation = np.array([x_offset,y_offset,z_offset])
+    bboxes_translated = bboxes + random_translation
+    return bboxes_translated
+
+
+def edge2loop(face_edges):
+    """Pair up edge endpoints by nearest distance; returns unique sorted index pairs (one per shared vertex)."""
+    face_edges_flatten = face_edges.reshape(-1,3)
+    # connect end points by closest distance
+    merged_vertex_id = []
+    for edge_idx, startend in enumerate(face_edges):
+        self_id = [2*edge_idx, 2*edge_idx+1]
+        # left endpoint
+        distance = np.linalg.norm(face_edges_flatten - startend[0], axis=1)
+        min_id = list(np.argsort(distance))
+        min_id_noself = [x for x in min_id if x not in self_id]
+        merged_vertex_id.append(sorted([2*edge_idx, min_id_noself[0]]))
+        # right endpoint
+        distance = np.linalg.norm(face_edges_flatten - startend[1], axis=1)
+        min_id = list(np.argsort(distance))
+        min_id_noself = [x for x in min_id if x not in self_id]
+        merged_vertex_id.append(sorted([2*edge_idx+1, min_id_noself[0]]))
+
+    merged_vertex_id = np.unique(np.array(merged_vertex_id),axis=0)
+    return merged_vertex_id
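+
+
+# Worked example (a sketch): two edges A->B and B->A forming a closed 2-edge loop.
+# Flattened endpoint ids are 0:A, 1:B (edge 0) and 2:B, 3:A (edge 1); the nearest
+# non-self endpoints pair up as [0, 3] and [1, 2], one merge per shared vertex:
+#   face_edges = np.array([[[0., 0, 0], [1, 0, 0]],
+#                          [[1., 0, 0], [0, 0, 0]]])
+#   edge2loop(face_edges)  # -> array([[0, 3], [1, 2]])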
+
+
+def keep_largelist(int_lists):
+    """Keep only the largest (non-subset) integer lists, deduplicated."""
+    # Initialize a list to store the largest integer lists
+    largest_int_lists = []
+
+    # Convert each list to a set for efficient comparison
+    sets = [set(lst) for lst in int_lists]
+
+    # Iterate through the sets and check if they are subsets of others
+    for i, s1 in enumerate(sets):
+        is_subset = False
+        for j, s2 in enumerate(sets):
+            if i!=j and s1.issubset(s2) and s1 != s2:
+                is_subset = True
+                break
+        if not is_subset:
+            largest_int_lists.append(list(s1))
+
+    # Initialize a set to keep track of seen tuples
+    seen_tuples = set()
+
+    # Initialize a list to store unique integer lists
+    unique_int_lists = []
+
+    # Iterate through the input list
+    for int_list in largest_int_lists:
+        # Convert the list to a tuple for hashing
+        int_tuple = tuple(sorted(int_list))
+
+        # Check if the tuple is not in the set of seen tuples
+        if int_tuple not in seen_tuples:
+            # Add the tuple to the set of seen tuples
+            seen_tuples.add(int_tuple)
+
+            # Add the original list to the list of unique integer lists
+            unique_int_lists.append(int_list)
+
+    return unique_int_lists
+
+
+def detect_shared_vertex(edgeV_cad, edge_mask_cad, edgeV_bbox):
+    """
+    Find the shared vertices
+    """
+    edge_id_offset = 2 * np.concatenate([np.array([0]),np.cumsum((edge_mask_cad==False).sum(1))])[:-1]
+    valid = True
+
+    # Detect shared vertices on each separate face loop
+    used_vertex = []
+    face_sep_merges = []
+    for face_idx, (face_edges, face_edges_mask, bbox_edges) in enumerate(zip(edgeV_cad, edge_mask_cad, edgeV_bbox)):
+        face_edges = face_edges[~face_edges_mask]
+        face_edges = face_edges.reshape(len(face_edges),2,3)
+        face_start_id = edge_id_offset[face_idx]
+
+        # connect end points by closest distance (edge bbox)
+        merged_vertex_id = edge2loop(bbox_edges)
+        if len(merged_vertex_id) == len(face_edges):
+            merged_vertex_id = face_start_id + merged_vertex_id
+            face_sep_merges.append(merged_vertex_id)
+            used_vertex.append(bbox_edges*3)
+            print('[PASS]')
+            continue
+
+        # connect end points by closest distance (vertex pos)
+        merged_vertex_id = edge2loop(face_edges)
+        if len(merged_vertex_id) == len(face_edges):
+            merged_vertex_id = face_start_id + merged_vertex_id
+            face_sep_merges.append(merged_vertex_id)
+            used_vertex.append(face_edges)
+            print('[PASS]')
+            continue
+
+        print('[FAILED]')
+        valid = False
+        break
+
+    # Invalid
+    if not valid:
+        assert False
+
+    # Detect shared vertices across faces
+    total_pnts = np.vstack(used_vertex)
+    total_pnts = total_pnts.reshape(len(total_pnts),2,3)
+    total_pnts_flatten = total_pnts.reshape(-1,3)
+
+    total_ids = []
+    for face_idx, face_merge in enumerate(face_sep_merges):
+        # non-self merge centers
+        nonself_face_idx = list(set(np.arange(len(face_sep_merges))) - set([face_idx]))
+        nonself_face_merges = [face_sep_merges[x] for x in nonself_face_idx]
+        nonself_face_merges = np.vstack(nonself_face_merges)
+        nonself_merged_centers = total_pnts_flatten[nonself_face_merges].mean(1)
+
+        # connect end points by closest distance
+        across_merge_id = []
+        for merge_id in face_merge:
+            merged_center = total_pnts_flatten[merge_id].mean(0)
+            distance = np.linalg.norm(nonself_merged_centers - merged_center, axis=1)
+            nonself_match_id = nonself_face_merges[np.argsort(distance)[0]]
+            joint_merge_id = list(nonself_match_id) + list(merge_id)
+            across_merge_id.append(joint_merge_id)
+        total_ids += across_merge_id
+
+    # Merge T-junctions
+    while (True):
+        no_merge = True
+        final_merge_id = []
+
+        # iteratively merge until no changes happen
+        for i in range(len(total_ids)):
+            perform_merge = False
+
+            for j in range(i+1,len(total_ids)):
+                # check if vertex can be further merged
+                max_num = max(len(total_ids[i]), len(total_ids[j]))
+                union = set(total_ids[i]).union(set(total_ids[j]))
+                common = set(total_ids[i]).intersection(set(total_ids[j]))
+                if len(union) > max_num and len(common)>0:
+                    final_merge_id.append(list(union))
+                    perform_merge = True
+                    no_merge = False
+                    break
+
+            if not perform_merge:
+                final_merge_id.append(total_ids[i])   # no-merge
+
+        total_ids = final_merge_id
+        if no_merge: break
+
+    # remove subsets
+    total_ids = keep_largelist(total_ids)
+
+    # merge again based on absolute coordinates, required for T-junctions of more than 3 edges
+    tobe_merged_centers = [total_pnts_flatten[x].mean(0) for x in total_ids]
+    tobe_centers = np.array(tobe_merged_centers)
+    distances = np.linalg.norm(tobe_centers[:, np.newaxis, :] - tobe_centers, axis=2)
+    close_points = distances < 0.1
+    mask = np.tril(np.ones_like(close_points, dtype=bool), k=-1)
+    non_diagonal_indices = np.where(close_points & mask)
+    row_indices, column_indices = non_diagonal_indices
+
+    # update the total_ids
+    total_ids_updated = []
+    for row, col in zip(row_indices, column_indices):
+        total_ids_updated.append(total_ids[row] + total_ids[col])
+    for index, ids in enumerate(total_ids):
+        if index not in list(row_indices) and index not in list(column_indices):
+            total_ids_updated.append(ids)
+    total_ids = total_ids_updated
+
+    # merged vertices
+    unique_vertices = []
+    for center_id in total_ids:
+        center_pnts = total_pnts_flatten[center_id].mean(0) / 3.0
+        unique_vertices.append(center_pnts)
+    unique_vertices = np.vstack(unique_vertices)
+
+    new_vertex_dict = {}
+    for new_id, old_ids in enumerate(total_ids):
+        new_vertex_dict[new_id] = old_ids
+
+    return [unique_vertices, new_vertex_dict]
+
+
+def detect_shared_edge(unique_vertices, new_vertex_dict, edge_z_cad, surf_z_cad, z_threshold, edge_mask_cad):
+    """
+    Find the shared edges
+    """
+    init_edges = edge_z_cad
+
+    # re-assign edge start/end to unique vertices
+    new_ids = []
+    for old_id in np.arange(2*len(init_edges)):
+        new_id = []
+        for key, value in new_vertex_dict.items():
+            # Check if the desired number is in the associated list
+            if old_id in value:
+                new_id.append(key)
+        assert len(new_id) == 1   # should only return one unique value
+        new_ids.append(new_id[0])
+
+    EdgeVertexAdj = np.array(new_ids).reshape(-1,2)
+
+    # find edges assigned to the same start/end
+    similar_edges = []
+    for i, s1 in enumerate(EdgeVertexAdj):
+        for j, s2 in enumerate(EdgeVertexAdj):
+            if i!=j and set(s1) == set(s2):   # same start/end
+                z1 = init_edges[i]
+                z2 = init_edges[j]
+                z_diff = np.abs(z1-z2).mean()
+                if z_diff < z_threshold:   # check z difference
+                    similar_edges.append(sorted([i,j]))
+                # else:
+                #     print('z latent beyond...')
+    similar_edges = np.unique(np.array(similar_edges),axis=0)
+
+    # every edge is shared by two faces, so deduplication should halve the count
+    if not 2*len(similar_edges) == len(EdgeVertexAdj):
+        assert False, 'edge not reduced by 2'
+
+    # unique edges
+    unique_edge_id = similar_edges[:,0]
+    EdgeVertexAdj = EdgeVertexAdj[unique_edge_id]
+    unique_edges = init_edges[unique_edge_id]
+
+    # unique faces
+    unique_faces = surf_z_cad
+    FaceEdgeAdj = []
+    ranges = np.concatenate([np.array([0]),np.cumsum((edge_mask_cad==False).sum(1))])
+    for index in range(len(ranges)-1):
+        adj_ids = np.arange(ranges[index], ranges[index+1])
+        new_ids = []
+        for id in adj_ids:
+            new_id = np.where(similar_edges == id)[0]
+            assert len(new_id) == 1
+            new_ids.append(new_id[0])
+        FaceEdgeAdj.append(new_ids)
+
+    print(f'Post-process: F-{len(unique_faces)} E-{len(unique_edges)} V-{len(unique_vertices)}')
+
+    return [unique_faces, unique_edges, FaceEdgeAdj, EdgeVertexAdj]
+
+
+class STModel(nn.Module):
+    def __init__(self, num_edge, num_surf):
+        super().__init__()
+        self.edge_t = nn.Parameter(torch.zeros((num_edge, 3)))
+        self.surf_st = nn.Parameter(torch.FloatTensor([1,0,0,0]).unsqueeze(0).repeat(num_surf,1))
+
+
+def get_bbox_minmax(point_cloud):
+    # Find the minimum and maximum coordinates along each axis
+    min_x = np.min(point_cloud[:, 0])
+    max_x = np.max(point_cloud[:, 0])
+
+    min_y = np.min(point_cloud[:, 1])
+    max_y = np.max(point_cloud[:, 1])
+
+    min_z = np.min(point_cloud[:, 2])
+    max_z = np.max(point_cloud[:, 2])
+
+    # Create the 3D bounding box using the min and max values
+    min_point = np.array([min_x, min_y, min_z])
+    max_point = np.array([max_x, max_y, max_z])
+    return (min_point, max_point)
+
+
+def joint_optimize(surf_ncs, edge_ncs, surfPos, unique_vertices, EdgeVertexAdj, FaceEdgeAdj, num_edge, num_surf):
+    """
+    Jointly optimize the face/edge/vertex based on topology
+    """
+    loss_func = ChamferDistance()
+
+    model = STModel(num_edge, num_surf)
+    model = model.cuda().train()
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=1e-3,
+        betas=(0.95, 0.999),
+        weight_decay=1e-6,
+        eps=1e-08,
+    )
+
+    # Optimize edges (directly compute)
+    edge_ncs_se = edge_ncs[:,[0,-1]]
+    edge_vertex_se = unique_vertices[EdgeVertexAdj]
+
+    edge_wcs = []
+    print('Joint Optimization...')
+    for wcs, ncs_se, vertex_se in zip(edge_ncs, edge_ncs_se, edge_vertex_se):
+        # scale
+        scale_target = np.linalg.norm(vertex_se[0] - vertex_se[1])
+        scale_ncs = np.linalg.norm(ncs_se[0] - ncs_se[1])
+        edge_scale = scale_target / scale_ncs
+
+        edge_updated = wcs*edge_scale
+        edge_se = ncs_se*edge_scale
+
+        # offset
+        offset = (vertex_se - edge_se)
+        offset_rev = (vertex_se - edge_se[::-1])
+
+        # swap start / end if necessary
+        offset_error = np.abs(offset[0] - offset[1]).mean()
+        offset_rev_error = np.abs(offset_rev[0] - offset_rev[1]).mean()
+        if offset_rev_error < offset_error:
+            edge_updated = edge_updated[::-1]
+            offset = offset_rev
+
+        edge_updated = edge_updated + offset.mean(0)[np.newaxis,np.newaxis,:]
+        edge_wcs.append(edge_updated)
+
+    edge_wcs = np.vstack(edge_wcs)
+
+    # Replace start/end points with the corners, and propagate the change along the curve
+    for index in range(len(edge_wcs)):
+        start_vec = edge_vertex_se[index,0] - edge_wcs[index, 0]
+        end_vec = edge_vertex_se[index,1] - edge_wcs[index, -1]
+        weight = np.tile((np.arange(32)/31)[:,np.newaxis], (1,3))
+        weighted_vec = np.tile(start_vec[np.newaxis,:],(32,1))*(1-weight) + np.tile(end_vec,(32,1))*weight
+        edge_wcs[index] += weighted_vec
+
+    # Optimize surfaces
+    face_edges = []
+    for adj in FaceEdgeAdj:
+        all_pnts = edge_wcs[adj]
+        face_edges.append(torch.FloatTensor(all_pnts).cuda())
+
+    # Initialize surfaces in wcs based on the surface pos
+    surf_wcs_init = []
+    bbox_threshold_min = []
+    bbox_threshold_max = []
+    for edges_perface, ncs, bbox in zip(face_edges, surf_ncs, surfPos):
+        surf_center, surf_scale = compute_bbox_center_and_size(bbox[0:3], bbox[3:])
+        edges_perface_flat = edges_perface.reshape(-1, 3).detach().cpu().numpy()
+        min_point, max_point = get_bbox_minmax(edges_perface_flat)
+        edge_center, edge_scale = compute_bbox_center_and_size(min_point, max_point)
+        bbox_threshold_min.append(min_point)
+        bbox_threshold_max.append(max_point)
+
+        # increase surface size if it does not fully cover the wire bbox
+        if surf_scale < edge_scale:
+            surf_scale = 1.05*edge_scale
+
+        wcs = ncs * (surf_scale/2) + surf_center
+        surf_wcs_init.append(wcs)
+
+    surf_wcs_init = np.stack(surf_wcs_init)
+
+    # optimize the surface offset
+    surf = torch.FloatTensor(surf_wcs_init).cuda()
+    for iters in range(200):
+        surf_scale = model.surf_st[:,0].reshape(-1,1,1,1)   # NOTE: currently unused; only the offset is applied
+        surf_offset = model.surf_st[:,1:].reshape(-1,1,1,3)
+        surf_updated = surf + surf_offset
+
+        surf_loss = 0
+        for surf_pnt, edge_pnts in zip(surf_updated, face_edges):
+            surf_pnt = surf_pnt.reshape(-1,3)
+            edge_pnts = edge_pnts.reshape(-1,3).detach()
+            surf_loss += loss_func(surf_pnt.unsqueeze(0), edge_pnts.unsqueeze(0), bidirectional=False, reverse=True)
+        surf_loss /= len(surf_updated)
+
+        optimizer.zero_grad()
+        (surf_loss).backward()
+        optimizer.step()
+
+        # print(f'Iter {iters} surf:{surf_loss:.5f}')
+
+    surf_wcs = surf_updated.detach().cpu().numpy()
+
+    return (surf_wcs, edge_wcs)
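+
+
+# Usage sketch (hypothetical arrays; shapes follow the callers in this file):
+#   surf_wcs, edge_wcs = joint_optimize(surf_ncs, edge_ncs, surfPos, unique_vertices,
+#                                       EdgeVertexAdj, FaceEdgeAdj,
+#                                       num_edge=len(edge_ncs), num_surf=len(surf_ncs))
+# Edges are snapped to the merged vertices in closed form; the per-surface offset is
+# then fitted with a one-sided Chamfer loss that pulls each surface grid onto its
+# boundary wires.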
+
+
+def add_pcurves_to_edges(face):
+    edge_fixer = ShapeFix_Edge()
+    top_exp = TopologyExplorer(face)
+    for wire in top_exp.wires():
+        wire_exp = WireExplorer(wire)
+        for edge in wire_exp.ordered_edges():
+            edge_fixer.FixAddPCurve(edge, face, False, 0.001)
+
+
+def fix_wires(face, debug=False):
+    top_exp = TopologyExplorer(face)
+    for wire in top_exp.wires():
+        if debug:
+            wire_checker = ShapeAnalysis_Wire(wire, face, 0.01)
+            print(f"Check order 3d {wire_checker.CheckOrder()}")
+            print(f"Check 3d gaps {wire_checker.CheckGaps3d()}")
+            print(f"Check closed {wire_checker.CheckClosed()}")
+            print(f"Check connected {wire_checker.CheckConnected()}")
+        wire_fixer = ShapeFix_Wire(wire, face, 0.01)
+
+        # wire_fixer.SetClosedWireMode(True)
+        # wire_fixer.SetFixConnectedMode(True)
+        # wire_fixer.SetFixSeamMode(True)
+
+        assert wire_fixer.IsReady()
+        ok = wire_fixer.Perform()
+        # assert ok
+
+
+def fix_face(face):
+    fixer = ShapeFix_Face(face)
+    fixer.SetPrecision(0.01)
+    fixer.SetMaxTolerance(0.1)
+    ok = fixer.Perform()
+    # assert ok
+    fixer.FixOrientation()
+    face = fixer.Face()
+    return face
+
+
+def construct_brep(surf_wcs, edge_wcs, FaceEdgeAdj, EdgeVertexAdj):
+    """
+    Fit parametric surfaces / curves and trim into B-rep
+    """
+    print('Building the B-rep...')
+    # Fit surface bspline
+    recon_faces = []
+    for points in surf_wcs:
+        num_u_points, num_v_points = 32, 32
+        uv_points_array = TColgp_Array2OfPnt(1, num_u_points, 1, num_v_points)
+        for u_index in range(1,num_u_points+1):
+            for v_index in range(1,num_v_points+1):
+                pt = points[u_index-1, v_index-1]
+                point_3d = gp_Pnt(float(pt[0]), float(pt[1]), float(pt[2]))
+                uv_points_array.SetValue(u_index, v_index, point_3d)
+        approx_face = GeomAPI_PointsToBSplineSurface(uv_points_array, 3, 8, GeomAbs_C2, 5e-2).Surface()
+        recon_faces.append(approx_face)
+
+    recon_edges = []
+    for points in edge_wcs:
+        num_u_points = 32
+        u_points_array = TColgp_Array1OfPnt(1, num_u_points)
+        for u_index in range(1,num_u_points+1):
+            pt = points[u_index-1]
+            point_3d = gp_Pnt(float(pt[0]), float(pt[1]), float(pt[2]))
+            u_points_array.SetValue(u_index, point_3d)
+        try:
+            approx_edge = GeomAPI_PointsToBSpline(u_points_array, 0, 8, GeomAbs_C2, 5e-3).Curve()
+        except Exception as e:
+            print('high precision failed, trying mid precision...')
+            try:
+                approx_edge = GeomAPI_PointsToBSpline(u_points_array, 0, 8, GeomAbs_C2, 8e-3).Curve()
+            except Exception as e:
+                print('mid precision failed, trying low precision...')
+                approx_edge = GeomAPI_PointsToBSpline(u_points_array, 0, 8, GeomAbs_C2, 5e-2).Curve()
+        recon_edges.append(approx_edge)
+
+    # Create edges from the curve list
+    edge_list = []
+    for curve in recon_edges:
+        edge = BRepBuilderAPI_MakeEdge(curve).Edge()
+        edge_list.append(edge)
+
+    # Cut surface by wire
+    post_faces = []
+    post_edges = []
+    for idx,(surface, edge_incides) in enumerate(zip(recon_faces, FaceEdgeAdj)):
+        corner_indices = EdgeVertexAdj[edge_incides]
+
+        # ordered loop
+        loops = []
+        ordered = [0]
+        seen_corners = [corner_indices[0,0], corner_indices[0,1]]
+        next_index = corner_indices[0,1]
+
+        while len(ordered)
+
+
+def process_brep_data(
+        data: dict,
+        max_face: int,
+        max_edge: int,
+        bbox_scaled: float,
+        aug: bool = False,
+        data_class: Optional[int] = None
+) -> Tuple[torch.Tensor, ...]:
+    """
+    Process B-rep data into fixed-size padded tensors.
+
+    Args:
+        data (dict): dictionary holding the B-rep data, laid out as:
+            {
+                'surf_ncs': np.ndarray,     # normalized surface point clouds [num_faces, 100, 3]
+                'edge_ncs': np.ndarray,     # normalized edge point clouds    [num_edges, 10, 3]
+                'corner_wcs': np.ndarray,   # corner coordinates              [num_edges, 2, 3]
+                'faceEdge_adj': np.ndarray, # face-edge adjacency matrix      [num_faces, num_edges]
+                'surf_pos': np.ndarray,     # surface positions               [num_faces, 6]
+                'edge_pos': np.ndarray,     # edge positions                  [num_edges, 6]
+            }
+        max_face (int): maximum number of faces
+        max_edge (int): maximum number of edges per face
+        bbox_scaled (float): bounding-box scaling factor
+        aug (bool): whether to apply data augmentation
+        data_class (Optional[int]): class label of the sample
+
+    Returns:
+        Tuple[torch.Tensor, ...]: tuple containing the following tensors:
+            - edge_ncs: normalized edge features    [max_face, max_edge, 10, 3]
+            - edge_pos: edge positions              [max_face, max_edge, 6]
+            - edge_mask: edge mask                  [max_face, max_edge]
+            - surf_ncs: normalized surface features [max_face, 100, 3]
+            - surf_pos: surface positions           [max_face, 6]
+            - vertex_pos: vertex positions          [max_face, max_edge, 6]
+            - data_class: (optional) class label    [1]
+    """
+    # Unpack the data (relies on the key insertion order of `data`)
+    _, _, surf_ncs, edge_ncs, corner_wcs, _, _, faceEdge_adj, surf_pos, edge_pos, _, _ = data.values()
+
+    # Data augmentation
+    random_num = np.random.rand()
+    if random_num > 0.5 and aug:
+        # Get the eight corners of each bounding box
+        surfpos_corners = bbox_corners(surf_pos)   # [num_faces, 8, 3]
+        edgepos_corners = bbox_corners(edge_pos)   # [num_edges, 8, 3]
+
+        # Random rotation
+        for axis in ['x', 'y', 'z']:
+            angle = random.choice([90, 180, 270])
+            # Rotate every geometric element consistently, preserving shapes
+            surfpos_corners = rotate_axis(surfpos_corners, angle, axis, normalized=True)
+            edgepos_corners = rotate_axis(edgepos_corners, angle, axis, normalized=True)
+            corner_wcs = rotate_axis(corner_wcs, angle, axis, normalized=True)
+            surf_ncs = rotate_axis(surf_ncs, angle, axis, normalized=False)
+            edge_ncs = rotate_axis(edge_ncs, angle, axis, normalized=False)
+
+        # Recompute the bounding boxes
+        surf_pos = get_bbox(surfpos_corners)        # [num_faces, 2, 3]
+        surf_pos = surf_pos.reshape(len(surf_pos), 6)   # [num_faces, 6]
+        edge_pos = get_bbox(edgepos_corners)        # [num_edges, 2, 3]
+        edge_pos = edge_pos.reshape(len(edge_pos), 6)   # [num_edges, 6]
+
+    # Scale the value ranges
+    surf_pos = surf_pos * bbox_scaled      # [num_faces, 6]
+    edge_pos = edge_pos * bbox_scaled      # [num_edges, 6]
+    corner_wcs = corner_wcs * bbox_scaled  # [num_edges, 2, 3]
+
+    # Duplicate the edge features per face
+    edge_pos_duplicated: List[np.ndarray] = []    # each element: [num_edges_per_face, 6]
+    vertex_pos_duplicated: List[np.ndarray] = []  # each element: [num_edges_per_face, 6]
+    edge_ncs_duplicated: List[np.ndarray] = []    # each element: [num_edges_per_face, 10, 3]
+
+    for adj in faceEdge_adj:   # [num_faces, num_edges]
+        edge_ncs_duplicated.append(edge_ncs[adj])   # [num_edges_per_face, 10, 3]
+        edge_pos_duplicated.append(edge_pos[adj])   # [num_edges_per_face, 6]
+        corners = corner_wcs[adj]   # [num_edges_per_face, 2, 3]
+        corners_sorted = []
+        for corner in corners:   # [2, 3]
+            sorted_indices = np.lexsort((corner[:, 2], corner[:, 1], corner[:, 0]))
+            corners_sorted.append(corner[sorted_indices].flatten())   # [6]
+        corners = np.stack(corners_sorted)   # [num_edges_per_face, 6]
+        vertex_pos_duplicated.append(corners)
+
+    # Shuffle and pad the edge features
+    edge_pos_new = []   # final shape: [num_faces, max_edge, 6]
+    edge_ncs_new = []   # final shape: [num_faces, max_edge, 10, 3]
+    vert_pos_new = []   # final shape: [num_faces, max_edge, 6]
+    edge_mask = []      # final shape: [num_faces, max_edge]
+
+    for pos, ncs, vert in zip(edge_pos_duplicated, edge_ncs_duplicated, vertex_pos_duplicated):
+        random_indices = np.random.permutation(pos.shape[0])
+        pos = pos[random_indices]    # [num_edges_per_face, 6]
+        ncs = ncs[random_indices]    # [num_edges_per_face, 10, 3]
+        vert = vert[random_indices]  # [num_edges_per_face, 6]
+
+        pos, mask = pad_zero(pos, max_edge, return_mask=True)   # [max_edge, 6], [max_edge]
+        ncs = pad_zero(ncs, max_edge)    # [max_edge, 10, 3]
+        vert = pad_zero(vert, max_edge)  # [max_edge, 6]
+
+        edge_pos_new.append(pos)
+        edge_ncs_new.append(ncs)
+        edge_mask.append(mask)
+        vert_pos_new.append(vert)
+
+    edge_pos = np.stack(edge_pos_new)     # [num_faces, max_edge, 6]
+    edge_ncs = np.stack(edge_ncs_new)     # [num_faces, max_edge, 10, 3]
+    edge_mask = np.stack(edge_mask)       # [num_faces, max_edge]
+    vertex_pos = np.stack(vert_pos_new)   # [num_faces, max_edge, 6]
+
+    # Shuffle the face features
+    random_indices = np.random.permutation(surf_pos.shape[0])
+    surf_pos = surf_pos[random_indices]      # [num_faces, 6]
+    edge_pos = edge_pos[random_indices]      # [num_faces, max_edge, 6]
+    surf_ncs = surf_ncs[random_indices]      # [num_faces, 100, 3]
+    edge_ncs = edge_ncs[random_indices]      # [num_faces, max_edge, 10, 3]
+    edge_mask = edge_mask[random_indices]    # [num_faces, max_edge]
+    vertex_pos = vertex_pos[random_indices]  # [num_faces, max_edge, 6]
+
+    # Pad up to the maximum face count
+    surf_pos = pad_zero(surf_pos, max_face)      # [max_face, 6]
+    surf_ncs = pad_zero(surf_ncs, max_face)      # [max_face, 100, 3]
+    edge_pos = pad_zero(edge_pos, max_face)      # [max_face, max_edge, 6]
+    edge_ncs = pad_zero(edge_ncs, max_face)      # [max_face, max_edge, 10, 3]
+    vertex_pos = pad_zero(vertex_pos, max_face)  # [max_face, max_edge, 6]
+
+    # Extend the edge mask to the padded faces (padded faces are fully masked)
+    padding = np.zeros((max_face-len(edge_mask), *edge_mask.shape[1:])) == 0   # [max_face-num_faces, max_edge]
+    edge_mask = np.concatenate([edge_mask, padding], 0)   # [max_face, max_edge]
+
+    # Convert to tensors and return
+    if data_class is not None:
+        return (
+            torch.FloatTensor(edge_ncs),      # [max_face, max_edge, 10, 3]
+            torch.FloatTensor(edge_pos),      # [max_face, max_edge, 6]
+            torch.BoolTensor(edge_mask),      # [max_face, max_edge]
+            torch.FloatTensor(surf_ncs),      # [max_face, 100, 3]
+            torch.FloatTensor(surf_pos),      # [max_face, 6]
+            torch.FloatTensor(vertex_pos),    # [max_face, max_edge, 6]
+            torch.LongTensor([data_class+1])  # [1]
+        )
+    else:
+        return (
+            torch.FloatTensor(edge_ncs),    # [max_face, max_edge, 10, 3]
+            torch.FloatTensor(edge_pos),    # [max_face, max_edge, 6]
+            torch.BoolTensor(edge_mask),    # [max_face, max_edge]
+            torch.FloatTensor(surf_ncs),    # [max_face, 100, 3]
+            torch.FloatTensor(surf_pos),    # [max_face, 6]
+            torch.FloatTensor(vertex_pos)   # [max_face, max_edge, 6]
+        )
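+
+
+# Usage sketch for process_brep_data (hypothetical values; `data` must follow the
+# layout documented in the docstring above, e.g. a pickle produced by the data
+# pipeline):
+#   import pickle
+#   with open('chair_1105.pkl', 'rb') as f:
+#       data = pickle.load(f)
+#   edge_ncs, edge_pos, edge_mask, surf_ncs, surf_pos, vertex_pos = \
+#       process_brep_data(data, max_face=70, max_edge=70, bbox_scaled=1.0)
+#   # edge_ncs: [70, 70, 10, 3], surf_pos: [70, 6]; edge_mask is True on padded slots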