You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
119 lines
4.7 KiB
119 lines
4.7 KiB
|
|
|
|
|
|
'''
|
|
class GridNet:
|
|
def __init__(self,
|
|
surf_wcs, edge_wcs, surf_ncs, edge_ncs, corner_wcs, corner_unique,
|
|
edgeFace_adj, edgeCorner_adj, faceEdge_adj,
|
|
surf_bbox, edge_bbox_wcs):
|
|
"""
|
|
初始化 GridNet
|
|
|
|
参数:
|
|
# 几何数据
|
|
'surf_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示面的点云坐标
|
|
'edge_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示边的采样点坐标
|
|
'surf_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云
|
|
'edge_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示归一化后的边采样点
|
|
'corner_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 2, 3)的数组,表示每条边的两个端点坐标
|
|
'corner_unique': np.ndarray(dtype=float32) # 形状为(num_vertices, 3)的数组,表示所有顶点的唯一坐标,num_vertices <= num_edges * 2
|
|
|
|
# 拓扑关系
|
|
'edgeFace_adj': np.ndarray(dtype=int32) # 形状为(num_edges, num_faces)的数组,表示边-面邻接关系
|
|
'edgeCorner_adj': np.ndarray(dtype=int32) # 形状为(num_edges, 2)的数组,表示边-顶点邻接关系
|
|
'faceEdge_adj': np.ndarray(dtype=int32) # 形状为(num_faces, num_edges)的数组,表示面-边邻接关系
|
|
|
|
# 包围盒数据
|
|
'surf_bbox': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
|
|
'edge_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 6)的数组,表示每条边的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
|
|
"""
|
|
self.surf_wcs = surf_wcs
|
|
self.edge_wcs = edge_wcs
|
|
self.surf_ncs = surf_ncs
|
|
self.edge_ncs = edge_ncs
|
|
self.corner_wcs = corner_wcs
|
|
self.corner_unique = corner_unique
|
|
self.edgeFace_adj = edgeFace_adj
|
|
self.edgeCorner_adj = edgeCorner_adj
|
|
self.faceEdge_adj = faceEdge_adj
|
|
self.surf_bbox = surf_bbox
|
|
self.edge_bbox_wcs = edge_bbox_wcs
|
|
|
|
# net
|
|
self.decoder = # MLP
|
|
self.root: OctreeNode = OctreeNode(initial_bbox, max_depth=max_depth)
|
|
'''
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.autograd import grad
|
|
from .encoder import Encoder
|
|
from .decoder import Decoder, CSGCombiner
|
|
from brep2sdf.utils.logger import logger
|
|
|
|
class Net(nn.Module):
|
|
def __init__(self,
|
|
octree,
|
|
volume_bboxs,
|
|
feature_dim=64,
|
|
decoder_input_dim=64,
|
|
decoder_output_dim=1,
|
|
decoder_hidden_dim=256,
|
|
decoder_num_layers=4,
|
|
decoder_activation='relu',
|
|
decoder_skip_connections=True):
|
|
|
|
super().__init__()
|
|
|
|
self.octree_module = octree.to("cpu")
|
|
|
|
# 初始化 Encoder
|
|
self.encoder = Encoder(
|
|
feature_dim=feature_dim,
|
|
volume_bboxs= volume_bboxs
|
|
)
|
|
|
|
# 初始化 Decoder
|
|
self.decoder = Decoder(
|
|
d_in=decoder_input_dim,
|
|
dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
|
|
geometric_init=True,
|
|
beta=100
|
|
)
|
|
|
|
self.csg_combiner = CSGCombiner(flag_convex=True)
|
|
|
|
def forward(self, query_points):
|
|
"""
|
|
前向传播
|
|
|
|
参数:
|
|
query_point: 查询点的位置坐标
|
|
返回:
|
|
output: 解码后的输出结果
|
|
"""
|
|
# 批量查询所有点的索引和bbox
|
|
_,face_indices,csg_trees = self.octree_module.forward(query_points)
|
|
# 编码
|
|
feature_vectors = self.encoder.forward(query_points,face_indices)
|
|
#print("feature_vector:", feature_vectors.requires_grad)
|
|
# 解码
|
|
logger.gpu_memory_stats("encoder farward后")
|
|
f_i = self.decoder(feature_vectors)
|
|
logger.gpu_memory_stats("decoder farward后")
|
|
output = self.csg_combiner.forward(f_i, csg_trees)
|
|
logger.gpu_memory_stats("combine后")
|
|
return output
|
|
|
|
|
|
def gradient(inputs, outputs):
|
|
d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
|
|
points_grad = grad(
|
|
outputs=outputs,
|
|
inputs=inputs,
|
|
grad_outputs=d_points,
|
|
create_graph=True,
|
|
retain_graph=True,
|
|
only_inputs=True)[0][:, -3:]
|
|
return points_grad
|