'''
class GridNet:
    def __init__(self,
                 surf_wcs, edge_wcs, surf_ncs, edge_ncs, corner_wcs, corner_unique,
                 edgeFace_adj, edgeCorner_adj, faceEdge_adj,
                 surf_bbox, edge_bbox_wcs):
        """
        Initialize GridNet.

        Parameters:
            # Geometry data
            'surf_wcs': np.ndarray(dtype=object)    # shape (N,); each element is a float32 array of shape (M, 3) holding the point cloud of one face
            'edge_wcs': np.ndarray(dtype=object)    # shape (N,); each element is a float32 array of shape (num_edge_sample_points, 3) holding the sampled points of one edge
            'surf_ncs': np.ndarray(dtype=object)    # shape (N,); each element is a float32 array of shape (M, 3) holding the normalized face point cloud
            'edge_ncs': np.ndarray(dtype=object)    # shape (N,); each element is a float32 array of shape (num_edge_sample_points, 3) holding the normalized edge sample points
            'corner_wcs': np.ndarray(dtype=float32)     # shape (num_edges, 2, 3); the two endpoint coordinates of each edge
            'corner_unique': np.ndarray(dtype=float32)  # shape (num_vertices, 3); unique vertex coordinates, num_vertices <= num_edges * 2
            # Topology
            'edgeFace_adj': np.ndarray(dtype=int32)     # shape (num_edges, num_faces); edge-face adjacency
            'edgeCorner_adj': np.ndarray(dtype=int32)   # shape (num_edges, 2); edge-vertex adjacency
            'faceEdge_adj': np.ndarray(dtype=int32)     # shape (num_faces, num_edges); face-edge adjacency
            # Bounding boxes
            'surf_bbox': np.ndarray(dtype=float32)      # shape (num_faces, 6); per-face bounding box [xmin, ymin, zmin, xmax, ymax, zmax]
            'edge_bbox_wcs': np.ndarray(dtype=float32)  # shape (num_edges, 6); per-edge bounding box [xmin, ymin, zmin, xmax, ymax, zmax]
        """
        self.surf_wcs = surf_wcs
        self.edge_wcs = edge_wcs
        self.surf_ncs = surf_ncs
        self.edge_ncs = edge_ncs
        self.corner_wcs = corner_wcs
        self.corner_unique = corner_unique
        self.edgeFace_adj = edgeFace_adj
        self.edgeCorner_adj = edgeCorner_adj
        self.faceEdge_adj = faceEdge_adj
        self.surf_bbox = surf_bbox
        self.edge_bbox_wcs = edge_bbox_wcs

        # net
        self.decoder = ...  # MLP
        self.root: OctreeNode = OctreeNode(initial_bbox, max_depth=max_depth)
'''
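# Illustrative example of the adjacency layout described in the note above
# (toy values, not taken from this repository): for a solid with 3 faces and
# 3 edges forming a closed loop,
#   edgeFace_adj[i, j] == 1 iff edge i lies on face j, e.g.
#       [[1, 1, 0],
#        [0, 1, 1],
#        [1, 0, 1]]
#   edgeCorner_adj[i] holds the two vertex ids of edge i, e.g. [[0, 1], [1, 2], [2, 0]]
#   faceEdge_adj expresses the same incidence relation transposed, with shape (num_faces, num_edges).
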
import torch
import torch.nn as nn
from torch.autograd import grad

from .encoder import Encoder
from .decoder import Decoder
from brep2sdf.utils.logger import logger


class Net(nn.Module):
    def __init__(self,
                 octree,
                 volume_bboxs,
                 feature_dim=8,
                 decoder_output_dim=1,
                 decoder_hidden_dim=256,
                 decoder_num_layers=4,
                 decoder_activation='relu',
                 decoder_skip_connections=True):
        super().__init__()
        self.octree_module = octree.to("cpu")

        # Initialize the encoder
        self.encoder = Encoder(
            feature_dim=feature_dim,
            volume_bboxs=volume_bboxs
        )

        # Initialize the decoder
        self.decoder = Decoder(
            d_in=feature_dim,
            dims_sdf=[decoder_hidden_dim] * decoder_num_layers,
            geometric_init=True,
            beta=5
        )
        # self.csg_combiner = CSGCombiner(flag_convex=True)
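
    # Rough data flow through the methods below (shapes inferred from the
    # inline comments, so treat them as an assumption): query_points (B, 3)
    # -> octree_module -> face_indices_mask and operator; encoder -> per-point
    # feature vectors; decoder -> f_i (B, P) per-patch values; process_sdf
    # -> (B,) combined SDF.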
    def process_sdf(self, f_i, face_indices_mask, operator):
        output = f_i[:, 0]

        # Extract the valid values and pad to a fixed size (B, max_patches)
        padded_f_i = torch.full((f_i.shape[0], 2), float('inf'), device=f_i.device)  # (B, max_patches)
        masked_f_i = torch.where(face_indices_mask, f_i, torch.tensor(float('inf'), device=f_i.device))  # invalid entries set to inf

        # Take the max_patches smallest valid values per sample (B, max_patches)
        valid_values, _ = torch.topk(masked_f_i, k=2, dim=1, largest=False)  # two smallest valid values
        padded_f_i[:, :2] = valid_values  # (B, max_patches)

        # Rows that need to be combined
        mask_concave = (operator == 0)
        mask_convex = (operator == 1)

        # operator == 0: combine the two values with the minimum
        if mask_concave.any():
            output[mask_concave] = torch.min(padded_f_i[mask_concave], dim=1).values
        # operator == 1: combine the two values with the maximum
        if mask_convex.any():
            output[mask_convex] = torch.max(padded_f_i[mask_convex], dim=1).values
        # logger.gpu_memory_stats("after combine")
        return output
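
    # The combination above on illustrative numbers: if the two nearest patch
    # SDFs for a point are [0.2, -0.1], operator == 0 keeps
    # min(0.2, -0.1) = -0.1 (the standard union-style blend of SDFs), while
    # operator == 1 keeps max(0.2, -0.1) = 0.2 (the intersection-style blend).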
    @torch.jit.export
    def forward(self, query_points):
        """
        Forward pass.

        Parameters:
            query_points: coordinates of the query points
        Returns:
            output: decoded SDF values
        """
        # Query patch indices and bounding boxes for all points in one batch
        # logger.debug("step octree")
        _, face_indices_mask, operator = self.octree_module.forward(query_points)

        # Encode
        # logger.debug("step encode")
        feature_vectors = self.encoder.forward(query_points, face_indices_mask)
        # print("feature_vector:", feature_vectors.shape)

        # Decode
        # logger.debug("step decode")
        # logger.gpu_memory_stats("after encoder forward")
        f_i = self.decoder(feature_vectors)  # (B, P)
        # logger.gpu_memory_stats("after decoder forward")

        # Combine per-patch values into one SDF value per point
        # logger.debug("step combine")
        return self.process_sdf(f_i, face_indices_mask, operator)
    @torch.jit.export
    def forward_background(self, query_points):
        """
        Forward pass through the background (global) volume only.

        Parameters:
            query_points: coordinates of the query points
        Returns:
            h: decoded output
        """
        # No octree lookup here: encode directly against the background volume
        feature_vectors = self.encoder.forward_background(query_points)
        # Decode
        h = self.decoder.forward_training_volumes(feature_vectors)  # (B, D)
        return h
    @torch.jit.export
    def forward_without_octree(self, query_points, face_indices_mask, operator):
        """
        Forward pass with precomputed octree outputs (skips the octree query).

        Parameters:
            query_points: coordinates of the query points
            face_indices_mask: per-point patch validity mask, as returned by the octree
            operator: per-point combination operator, as returned by the octree
        Returns:
            output: decoded SDF values
        """
        # Encode
        # logger.debug("step encode")
        feature_vectors = self.encoder.forward(query_points, face_indices_mask)
        # print("feature_vector:", feature_vectors.shape)

        # Decode
        f_i = self.decoder(feature_vectors)  # (B, P)
        # logger.gpu_memory_stats("after decoder forward")

        # Combine
        # logger.debug("step combine")
        return self.process_sdf(f_i, face_indices_mask, operator)
    @torch.jit.export
    def forward_training_volumes(self, surf_points, patch_id: int):
        """
        Evaluate surface-sampled points of a single patch only.

        surf_points: (P, S)
        return: (P, S)
        """
        feature_mat = self.encoder.forward_training_volumes(surf_points, patch_id)
        f_i = self.decoder.forward_training_volumes(feature_mat)
        return f_i.squeeze()
    def freeze_stage1(self):
        self.encoder.freeze_stage1()

    def freeze_stage2(self):
        self.encoder.freeze_stage2()
        for param in self.decoder.parameters():
            param.requires_grad = False

    def unfreeze(self):
        self.encoder.unfreeze()
        for param in self.decoder.parameters():
            param.requires_grad = True
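
    # A possible two-stage schedule built on the hooks above (an assumption;
    # the actual training loop is defined outside this file):
    #   net.freeze_stage1()   # stage 1: the encoder decides what stays trainable
    #   ...optimize...
    #   net.unfreeze()
    #   net.freeze_stage2()   # stage 2: also freezes all decoder parameters
    #   ...optimize...
    #   net.unfreeze()        # restore gradients everywhere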


def gradient(inputs, outputs):
    # Issue 1: inputs may contain non-coordinate features
    # Issue 2: special cases along the batch dimension are not handled
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)

    # Differentiate the outputs with respect to the inputs
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=d_points,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
        allow_unused=True  # tolerate inputs that do not contribute to the outputs
    )[0]

    # Handle the empty-gradient case
    if points_grad is None:
        return torch.zeros_like(inputs[:, -3:])

    # Keep only the coordinate part and normalize it safely
    coord_grad = points_grad[:, -3:] if points_grad.shape[1] >= 3 else points_grad
    coord_grad = coord_grad / (coord_grad.norm(dim=-1, keepdim=True) + 1e-6)  # safe normalization
    return coord_grad
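

# --- Usage sketch (illustrative only, not part of the module API) ---
# `gradient` pairs with `Net.forward`: enable grad on the query coordinates,
# evaluate the SDF, then differentiate it to obtain approximately unit surface
# normals. How `net` itself is constructed (octree, volume bboxes) is assumed
# to happen elsewhere; the helper name below is hypothetical.
def _example_sdf_and_normals(net: "Net", query_points: torch.Tensor):
    """Illustrative helper: evaluate SDF values and unit normals at query points."""
    query_points = query_points.detach().requires_grad_(True)  # gradients flow to the coordinates
    sdf = net(query_points)                    # (B,) combined signed distances
    normals = gradient(query_points, sdf)      # (B, 3) normalized spatial gradients
    return sdf, normals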