'''
class GridNet:
    def __init__(self,
                 surf_wcs, edge_wcs, surf_ncs, edge_ncs, corner_wcs, corner_unique,
                 edgeFace_adj, edgeCorner_adj, faceEdge_adj,
                 surf_bbox, edge_bbox_wcs):
        """
        Initialize GridNet.

        Parameters:
        # Geometry data
        'surf_wcs': np.ndarray(dtype=object)        # shape (N,); each element is a float32 array of shape (M, 3) holding the sampled point cloud of one face
        'edge_wcs': np.ndarray(dtype=object)        # shape (N,); each element is a float32 array of shape (num_edge_sample_points, 3) holding the sampled points of one edge
        'surf_ncs': np.ndarray(dtype=object)        # shape (N,); each element is a float32 array of shape (M, 3) holding the normalized face point cloud
        'edge_ncs': np.ndarray(dtype=object)        # shape (N,); each element is a float32 array of shape (num_edge_sample_points, 3) holding the normalized edge sample points
        'corner_wcs': np.ndarray(dtype=float32)     # shape (num_edges, 2, 3); the two endpoint coordinates of every edge
        'corner_unique': np.ndarray(dtype=float32)  # shape (num_vertices, 3); unique vertex coordinates, num_vertices <= num_edges * 2

        # Topology
        'edgeFace_adj': np.ndarray(dtype=int32)     # shape (num_edges, num_faces); edge-face adjacency
        'edgeCorner_adj': np.ndarray(dtype=int32)   # shape (num_edges, 2); edge-vertex adjacency
        'faceEdge_adj': np.ndarray(dtype=int32)     # shape (num_faces, num_edges); face-edge adjacency

        # Bounding boxes
        'surf_bbox': np.ndarray(dtype=float32)      # shape (num_faces, 6); per-face bounding box [xmin, ymin, zmin, xmax, ymax, zmax]
        'edge_bbox_wcs': np.ndarray(dtype=float32)  # shape (num_edges, 6); per-edge bounding box [xmin, ymin, zmin, xmax, ymax, zmax]
        """
        self.surf_wcs = surf_wcs
        self.edge_wcs = edge_wcs
        self.surf_ncs = surf_ncs
        self.edge_ncs = edge_ncs
        self.corner_wcs = corner_wcs
        self.corner_unique = corner_unique
        self.edgeFace_adj = edgeFace_adj
        self.edgeCorner_adj = edgeCorner_adj
        self.faceEdge_adj = faceEdge_adj
        self.surf_bbox = surf_bbox
        self.edge_bbox_wcs = edge_bbox_wcs

        # net
        self.decoder = None  # MLP (placeholder in this sketch)
        self.root: OctreeNode = OctreeNode(initial_bbox, max_depth=max_depth)
'''

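# Hedged illustration of the ragged data layout documented in the sketch above
# (not part of the original code; the array sizes are made up):
#   import numpy as np
#   surf_wcs = np.empty(2, dtype=object)
#   surf_wcs[0] = np.zeros((100, 3), dtype=np.float32)   # 100 sampled points on face 0
#   surf_wcs[1] = np.zeros((250, 3), dtype=np.float32)   # 250 sampled points on face 1
#   corner_wcs = np.zeros((3, 2, 3), dtype=np.float32)   # 3 edges, 2 endpoints each
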
import torch
import torch.nn as nn
from torch.autograd import grad

from .encoder import Encoder
from .decoder import Decoder
from brep2sdf.utils.logger import logger


class Net(nn.Module):
|
|
def __init__(self,
|
|
octree,
|
|
volume_bboxs,
|
|
feature_dim=8,
|
|
decoder_output_dim=1,
|
|
decoder_hidden_dim=256,
|
|
decoder_num_layers=4,
|
|
decoder_activation='relu',
|
|
decoder_skip_connections=True):
|
|
|
|
super().__init__()
|
|
|
|
self.octree_module = octree.to("cpu")
|
|
|
|
# 初始化 Encoder
|
|
self.encoder = Encoder(
|
|
feature_dim=feature_dim,
|
|
volume_bboxs= volume_bboxs
|
|
)
|
|
|
|
# 初始化 Decoder
|
|
self.decoder = Decoder(
|
|
d_in=feature_dim,
|
|
dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
|
|
geometric_init=True,
|
|
beta=5
|
|
)
|
|
|
|
#self.csg_combiner = CSGCombiner(flag_convex=True)
|
|
|
|
def process_sdf(self,f_i, face_indices_mask, operator):
|
|
output = f_i[:,0]
|
|
# 提取有效值并填充到固定大小 (B, max_patches)
|
|
padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device) # (B, max_patches)
|
|
masked_f_i = torch.where(face_indices_mask, f_i, torch.tensor(float('inf'), device=f_i.device)) # 将无效值设置为 inf
|
|
|
|
# 对每个样本取前 max_patches 个有效值 (B, max_patches)
|
|
valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False) # 提取前两个有效值
|
|
|
|
# 填充到固定大小 (B, max_patches)
|
|
padded_f_i[:, :2] = valid_values # (B, max_patches)
|
|
|
|
# 找到需要组合的行
|
|
mask_concave = (operator == 0)
|
|
mask_convex = (operator == 1)
|
|
|
|
# 对 operator == 0 的样本取最大值
|
|
if mask_concave.any():
|
|
output[mask_concave] = torch.min(padded_f_i[mask_concave], dim=1).values
|
|
|
|
# 对 operator == 1 的样本取最小值
|
|
if mask_convex.any():
|
|
output[mask_convex] = torch.max(padded_f_i[mask_convex], dim=1).values
|
|
|
|
#logger.gpu_memory_stats("combine后")
|
|
return output
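    # Hedged illustration (not part of the original code): for one query point with
    # per-patch predictions f_i = [[0.2, -0.1, 0.7]] and mask [[True, True, False]],
    # the masked values are [0.2, -0.1, inf], the two smallest are (-0.1, 0.2), and
    # the result is min = -0.1 for operator == 0 or max = 0.2 for operator == 1;
    # points matched by neither operator keep f_i[:, 0].
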
    @torch.jit.export
    def forward(self, query_points):
        """
        Forward pass.

        Parameters:
            query_points: coordinates of the query points
        Returns:
            output: decoded result (combined SDF values)
        """
        # Query patch indices and bounding boxes for all points in one batch
        # logger.debug("step octree")
        _, face_indices_mask, operator = self.octree_module.forward(query_points)
        # logger.debug("step encode")
        # Encode
        feature_vectors = self.encoder.forward(query_points, face_indices_mask)
        # print("feature_vector:", feature_vectors.shape)
        # Decode
        # logger.debug("step decode")
        # logger.gpu_memory_stats("after encoder forward")
        f_i = self.decoder(feature_vectors)  # (B, P)
        # logger.gpu_memory_stats("after decoder forward")

        # logger.debug("step combine")
        return self.process_sdf(f_i, face_indices_mask, operator)

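    # Shape sketch for the pipeline above (inferred from the calls, not stated in the
    # original): query_points (B, 3) -> octree_module -> face_indices_mask (B, P) bool
    # and operator (B,) -> encoder -> per-patch features -> decoder -> f_i (B, P)
    # -> process_sdf -> (B,) combined SDF values.
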
    @torch.jit.export
    def forward_background(self, query_points):
        """
        Forward pass through the background (global) branch.

        Parameters:
            query_points: coordinates of the query points
        Returns:
            h: decoded result
        """
        # Encode
        feature_vectors = self.encoder.forward_background(query_points)
        # Decode
        h = self.decoder.forward_training_volumes(feature_vectors)  # (B, D)
        return h

    @torch.jit.export
    def forward_without_octree(self, query_points, face_indices_mask, operator):
        """
        Forward pass with precomputed octree outputs.

        Parameters:
            query_points: coordinates of the query points
            face_indices_mask: per-point patch validity mask
            operator: per-point combination operator (0 or 1)
        Returns:
            output: decoded result (combined SDF values)
        """
        # logger.debug("step encode")
        # Encode
        feature_vectors = self.encoder.forward(query_points, face_indices_mask)
        # print("feature_vector:", feature_vectors.shape)
        # Decode
        f_i = self.decoder(feature_vectors)  # (B, P)
        # logger.gpu_memory_stats("after decoder forward")

        # logger.debug("step combine")
        return self.process_sdf(f_i, face_indices_mask, operator)

    @torch.jit.export
    def forward_training_volumes(self, surf_points, patch_id: int):
        """
        Forward pass for surface-sampled points of a single patch.

        surf_points: (P, S)
        returns: (P, S)
        """
        feature_mat = self.encoder.forward_training_volumes(surf_points, patch_id)
        f_i = self.decoder.forward_training_volumes(feature_mat)
        return f_i.squeeze()

    def freeze_stage1(self):
        self.encoder.freeze_stage1()

    def freeze_stage2(self):
        self.encoder.freeze_stage2()
        for param in self.decoder.parameters():
            param.requires_grad = False

    def unfreeze(self):
        self.encoder.unfreeze()
        for param in self.decoder.parameters():
            param.requires_grad = True

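    # Hedged usage sketch of the staged schedule suggested by the methods above
    # (assumed; the actual training script is not part of this file):
    #   net = Net(octree, volume_bboxs)
    #   net.freeze_stage1()   # stage 1 (delegates to Encoder.freeze_stage1)
    #   ...train...
    #   net.freeze_stage2()   # stage 2: Encoder.freeze_stage2 plus a frozen decoder
    #   ...train...
    #   net.unfreeze()        # joint fine-tuning with everything trainable
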
def gradient(inputs, outputs):
    # Known issue 1: inputs may contain non-coordinate features
    # Known issue 2: special cases of the batch dimension are not handled
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)

    # Improved gradient computation
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=d_points,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
        allow_unused=True  # tolerate inputs that do not affect outputs
    )[0]

    # Corrected dimension slicing
    if points_grad is None:
        return torch.zeros_like(inputs[:, -3:])  # handle the empty-gradient case

    # Safe slicing and normalization
    coord_grad = points_grad[:, -3:] if points_grad.shape[1] >= 3 else points_grad
    coord_grad = coord_grad / (coord_grad.norm(dim=-1, keepdim=True) + 1e-6)  # safe normalization

    return coord_grad
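
# Hedged usage sketch (not part of the original module): `gradient` returns the
# normalized spatial gradient of a scalar field with respect to its query points,
# e.g. surface normals of an SDF. The analytic unit-sphere SDF below is only an
# illustration; in the real pipeline the values would come from Net.forward.
if __name__ == "__main__":
    pts = torch.randn(16, 3, requires_grad=True)
    sdf = pts.norm(dim=-1, keepdim=True) - 1.0  # unit-sphere SDF, shape (16, 1)
    normals = gradient(pts, sdf)                # (16, 3), approximately unit length
    print(torch.allclose(normals, pts / pts.norm(dim=-1, keepdim=True), atol=1e-4))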