You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

119 lines
4.7 KiB

'''
class GridNet:
def __init__(self,
surf_wcs, edge_wcs, surf_ncs, edge_ncs, corner_wcs, corner_unique,
edgeFace_adj, edgeCorner_adj, faceEdge_adj,
surf_bbox, edge_bbox_wcs):
"""
初始化 GridNet
参数:
# 几何数据
'surf_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示面的点云坐标
'edge_wcs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示边的采样点坐标
'surf_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(M, 3)的float32数组,表示归一化后的面点云
'edge_ncs': np.ndarray(dtype=object) # 形状为(N,)的数组,每个元素是形状为(num_edge_sample_points, 3)的float32数组,表示归一化后的边采样点
'corner_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 2, 3)的数组,表示每条边的两个端点坐标
'corner_unique': np.ndarray(dtype=float32) # 形状为(num_vertices, 3)的数组,表示所有顶点的唯一坐标,num_vertices <= num_edges * 2
# 拓扑关系
'edgeFace_adj': np.ndarray(dtype=int32) # 形状为(num_edges, num_faces)的数组,表示边-面邻接关系
'edgeCorner_adj': np.ndarray(dtype=int32) # 形状为(num_edges, 2)的数组,表示边-顶点邻接关系
'faceEdge_adj': np.ndarray(dtype=int32) # 形状为(num_faces, num_edges)的数组,表示面-边邻接关系
# 包围盒数据
'surf_bbox': np.ndarray(dtype=float32) # 形状为(num_faces, 6)的数组,表示每个面的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
'edge_bbox_wcs': np.ndarray(dtype=float32) # 形状为(num_edges, 6)的数组,表示每条边的包围盒[xmin,ymin,zmin,xmax,ymax,zmax]
"""
self.surf_wcs = surf_wcs
self.edge_wcs = edge_wcs
self.surf_ncs = surf_ncs
self.edge_ncs = edge_ncs
self.corner_wcs = corner_wcs
self.corner_unique = corner_unique
self.edgeFace_adj = edgeFace_adj
self.edgeCorner_adj = edgeCorner_adj
self.faceEdge_adj = faceEdge_adj
self.surf_bbox = surf_bbox
self.edge_bbox_wcs = edge_bbox_wcs
# net
self.decoder = # MLP
self.root: OctreeNode = OctreeNode(initial_bbox, max_depth=max_depth)
'''
import torch
import torch.nn as nn
from torch.autograd import grad
from .encoder import Encoder
from .decoder import Decoder, CSGCombiner
from brep2sdf.utils.logger import logger
class Net(nn.Module):
def __init__(self,
octree,
volume_bboxs,
feature_dim=64,
decoder_input_dim=64,
decoder_output_dim=1,
decoder_hidden_dim=256,
decoder_num_layers=4,
decoder_activation='relu',
decoder_skip_connections=True):
super().__init__()
self.octree_module = octree.to("cpu")
# 初始化 Encoder
self.encoder = Encoder(
feature_dim=feature_dim,
volume_bboxs= volume_bboxs
)
# 初始化 Decoder
self.decoder = Decoder(
d_in=decoder_input_dim,
dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
geometric_init=True,
beta=100
)
self.csg_combiner = CSGCombiner(flag_convex=True)
def forward(self, query_points):
"""
前向传播
参数:
query_point: 查询点的位置坐标
返回:
output: 解码后的输出结果
"""
# 批量查询所有点的索引和bbox
_,face_indices,csg_trees = self.octree_module.forward(query_points)
# 编码
feature_vectors = self.encoder.forward(query_points,face_indices)
#print("feature_vector:", feature_vectors.requires_grad)
# 解码
logger.gpu_memory_stats("encoder farward后")
f_i = self.decoder(feature_vectors)
logger.gpu_memory_stats("decoder farward后")
output = self.csg_combiner.forward(f_i, csg_trees)
logger.gpu_memory_stats("combine后")
return output
def gradient(inputs, outputs):
d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
points_grad = grad(
outputs=outputs,
inputs=inputs,
grad_outputs=d_points,
create_graph=True,
retain_graph=True,
only_inputs=True)[0][:, -3:]
return points_grad