11 changed files with 780 additions and 310 deletions
@ -1,42 +1,219 @@ |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
from torch import Tensor |
|||
import torch |
|||
import numpy as np |
|||
from typing import Tuple, List, Union |
|||
from brep2sdf.utils.logger import logger |
|||
|
|||
class Sine(nn.Module): |
|||
def __init(self): |
|||
super().__init__() |
|||
|
|||
def forward(self, input): |
|||
# See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30 |
|||
return torch.sin(30 * input) |
|||
|
|||
class Decoder(nn.Module): |
|||
def __init__(self, |
|||
input_dim: int, |
|||
output_dim: int, |
|||
hidden_dim: int = 256) : |
|||
""" |
|||
最简单的Decoder实现 |
|||
|
|||
参数: |
|||
input_dim: 输入维度 |
|||
output_dim: 输出维度 |
|||
hidden_dim: 隐藏层维度 (默认: 256) |
|||
""" |
|||
def __init__( |
|||
self, |
|||
d_in: int, |
|||
dims_sdf: List[int], |
|||
skip_in: Tuple[int, ...] = (), |
|||
flag_convex: bool = True, |
|||
geometric_init: bool = True, |
|||
radius_init: float = 1, |
|||
beta: float = 100, |
|||
) -> None: |
|||
super().__init__() |
|||
|
|||
# 三层全连接网络 |
|||
self.fc1 = nn.Linear(input_dim, hidden_dim) |
|||
self.fc2 = nn.Linear(hidden_dim, hidden_dim) |
|||
self.fc3 = nn.Linear(hidden_dim, output_dim) |
|||
|
|||
self.flag_convex = flag_convex |
|||
self.skip_in = skip_in |
|||
|
|||
dims_sdf = [d_in] + dims_sdf + [1] #[self.n_branch] |
|||
self.sdf_layers = len(dims_sdf) |
|||
for layer in range(0, len(dims_sdf) - 1): |
|||
if layer + 1 in skip_in: |
|||
out_dim = dims_sdf[layer + 1] - d_in |
|||
else: |
|||
out_dim = dims_sdf[layer + 1] |
|||
lin = nn.Linear(dims_sdf[layer], out_dim) |
|||
if geometric_init: |
|||
if layer == self.sdf_layers - 2: |
|||
torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims_sdf[layer]), std=0.00001) |
|||
torch.nn.init.constant_(lin.bias, -radius_init) |
|||
else: |
|||
torch.nn.init.constant_(lin.bias, 0.0) |
|||
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) |
|||
setattr(self, "sdf_"+str(layer), lin) |
|||
if geometric_init: |
|||
if beta > 0: |
|||
self.activation = nn.Softplus(beta=beta) |
|||
# vanilla relu |
|||
else: |
|||
self.activation = nn.ReLU() |
|||
else: |
|||
#siren |
|||
self.activation = Sine() |
|||
self.final_activation = nn.ReLU() |
|||
|
|||
# composite f_i to h |
|||
|
|||
def forward(self, x: Tensor) -> Tensor: |
|||
""" |
|||
前向传播 |
|||
|
|||
|
|||
def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor: |
|||
''' |
|||
:param feature_matrix: 形状为 (B, P, D) 的特征矩阵 |
|||
B: 批大小 |
|||
P: patch volume数量 |
|||
D: 特征维度 |
|||
:return: |
|||
f_i: 各patch的SDF值 (B, P) |
|||
''' |
|||
B, P, D = feature_matrix.shape |
|||
|
|||
# 展平处理 (B*P, D) |
|||
x = feature_matrix.view(-1, D) |
|||
|
|||
for layer in range(0, self.sdf_layers - 1): |
|||
lin = getattr(self, "sdf_" + str(layer)) |
|||
if layer in self.skip_in: |
|||
x = torch.cat([x, x], -1) / np.sqrt(2) # Fix undefined 'input' |
|||
|
|||
x = lin(x) |
|||
if layer < self.sdf_layers - 2: |
|||
x = self.activation(x) |
|||
output_value = x #all f_i |
|||
|
|||
# 恢复维度 (B, P) |
|||
f_i = output_value.view(B, P) |
|||
|
|||
return f_i |
|||
|
|||
|
|||
# 一个基础情形: 输入 fi 形状[P] 和 csg tree,凹凸组合输出h |
|||
#注意考虑如何批量处理 (B, P) 和 [csg tree] |
|||
class CSGCombiner: |
|||
def __init__(self, flag_convex: bool = True, rho: float = 0.05): |
|||
self.flag_convex = flag_convex |
|||
self.rho = rho |
|||
|
|||
def forward(self, f_i: torch.Tensor, csg_tree) -> torch.Tensor: |
|||
''' |
|||
:param f_i: 形状为 (B, P) 的各patch SDF值 |
|||
:param csg_tree: CSG树结构 |
|||
:return: 组合后的整体SDF (B,) |
|||
''' |
|||
logger.info("\n".join(f"第{i}个csg: {t}" for i,t in enumerate(csg_tree))) |
|||
B = f_i.shape[0] |
|||
results = [] |
|||
for i in range(B): |
|||
# 处理每个样本的CSG组合 |
|||
h = self.nested_cvx_output_soft_blend( |
|||
f_i[i].unsqueeze(0), |
|||
csg_tree, |
|||
self.flag_convex |
|||
) |
|||
results.append(h) |
|||
|
|||
return torch.cat(results, dim=0).squeeze(1) # 从(B,1)变为(B,) |
|||
|
|||
def nested_cvx_output_soft_blend( |
|||
self, |
|||
value_matrix: torch.Tensor, |
|||
list_operation: List[Union[int, List]], |
|||
cvx_flag: bool = True |
|||
) -> torch.Tensor: |
|||
list_value = [] |
|||
for v in list_operation: |
|||
if not isinstance(v, list): |
|||
list_value.append(v) |
|||
|
|||
op_mat = torch.zeros(value_matrix.shape[1], len(list_value), |
|||
device=value_matrix.device) |
|||
for i in range(len(list_value)): |
|||
op_mat[list_value[i]][i] = 1.0 |
|||
|
|||
mat_mul = torch.matmul(value_matrix, op_mat) |
|||
if len(list_operation) == len(list_value): |
|||
return self.max_soft_blend(mat_mul, self.rho) if cvx_flag \ |
|||
else self.min_soft_blend(mat_mul, self.rho) |
|||
|
|||
参数: |
|||
x: 输入张量 |
|||
返回: |
|||
输出张量 |
|||
""" |
|||
# 第一层 |
|||
h = F.relu(self.fc1(x)) |
|||
# 第二层 |
|||
h = F.relu(self.fc2(h)) |
|||
# 输出层 |
|||
out = self.fc3(h) |
|||
list_output = [mat_mul] |
|||
for v in list_operation: |
|||
if isinstance(v, list): |
|||
list_output.append( |
|||
self.nested_cvx_output_soft_blend( |
|||
value_matrix, v, not cvx_flag |
|||
) |
|||
) |
|||
|
|||
return out |
|||
return self.max_soft_blend(torch.cat(list_output, 1), self.rho) if cvx_flag \ |
|||
else self.min_soft_blend(torch.cat(list_output, 1), self.rho) |
|||
|
|||
def min_soft_blend(self, mat, rho): |
|||
res = mat[:,0] |
|||
for i in range(1, mat.shape[1]): |
|||
srho = res * res + mat[:,i] * mat[:,i] - rho * rho |
|||
res = res + mat[:,i] - torch.sqrt(res * res + mat[:,i] * mat[:,i] + 1.0/(8 * rho * rho) * srho * (srho - srho.abs())) |
|||
return res.unsqueeze(1) |
|||
|
|||
def max_soft_blend(self, mat, rho): |
|||
res = mat[:,0] |
|||
for i in range(1, mat.shape[1]): |
|||
srho = res * res + mat[:,i] * mat[:,i] - rho * rho |
|||
res = res + mat[:,i] + torch.sqrt(res * res + mat[:,i] * mat[:,i] + 1.0/(8 * rho * rho) * srho * (srho - srho.abs())) |
|||
return res.unsqueeze(1) |
|||
|
|||
|
|||
def test_csg_combiner(): |
|||
# 测试数据 (B=3, P=5) |
|||
f_i = torch.tensor([ |
|||
[1.0, 2.0, 3.0, 4.0, 5.0], |
|||
[0.5, 1.5, 2.5, 3.5, 4.5], |
|||
[-1.0, 0.0, 1.0, 2.0, 3.0] |
|||
]) |
|||
|
|||
# 每个样本使用不同的CSG树结构 |
|||
csg_trees = [ |
|||
[0, [1, 2]], # 使用索引0,1,2 |
|||
[[0, 1], 3], # 使用索引0,1,3 |
|||
[0, 1, [2, 4]] # 使用索引0,1,2,4 |
|||
] |
|||
|
|||
# 验证所有索引都有效 |
|||
P = f_i.shape[1] |
|||
for i, tree in enumerate(csg_trees): |
|||
def check_indices(node): |
|||
if isinstance(node, list): |
|||
for n in node: |
|||
check_indices(n) |
|||
else: |
|||
assert node < P, f"样本{i}的树包含无效索引{node},P={P}" |
|||
check_indices(tree) |
|||
|
|||
print("Input SDF values:") |
|||
print(f_i) |
|||
print("\nCSG Trees:") |
|||
for i, tree in enumerate(csg_trees): |
|||
print(f"Sample {i}: {tree}") |
|||
|
|||
# 测试凸组合 |
|||
print("\nTesting convex combination:") |
|||
combiner_convex = CSGCombiner(flag_convex=True) |
|||
h_convex = combiner_convex.forward(f_i, csg_trees) |
|||
print("Results:", h_convex) |
|||
|
|||
# 测试凹组合 |
|||
print("\nTesting concave combination:") |
|||
combiner_concave = CSGCombiner(flag_convex=False) |
|||
h_concave = combiner_concave.forward(f_i, csg_trees) |
|||
print("Results:", h_concave) |
|||
|
|||
# 测试不同rho值的软混合 |
|||
print("\nTesting soft blends:") |
|||
for rho in [0.01, 0.1, 0.5]: |
|||
combiner_soft = CSGCombiner(flag_convex=True, rho=rho) |
|||
h_soft = combiner_soft.forward(f_i, csg_trees) |
|||
print(f"rho={rho}:", h_soft) |
|||
|
|||
if __name__ == "__main__": |
|||
test_csg_combiner() |
|||
@ -1,4 +1,69 @@ |
|||
import torch |
|||
|
|||
model = torch.jit.load("/home/wch/brep2sdf/data/output_data/00000054.pt") |
|||
print(model) |
|||
from typing import List, Tuple |
|||
|
|||
def bbox_intersect(surf_bboxes: torch.Tensor, indices: torch.Tensor, child_bboxes: torch.Tensor) -> torch.Tensor: |
|||
''' |
|||
args: |
|||
surf_bboxes: [B, 6] - 表示多个包围盒的张量,每个包围盒由其最小和最大坐标定义。 |
|||
indices: 选择bbox, [N], N <= B - 用于选择特定包围盒的索引张量。 |
|||
child_bboxes: [8, 6] - 一个包围盒被分成八个子包围盒后的结果。 |
|||
return: |
|||
intersect_mask: [8, N] - 布尔掩码,表示每个子包围盒与选择的包围盒是否相交。 |
|||
''' |
|||
# 提取选中的边界框 |
|||
selected_bboxes = surf_bboxes[indices] # 形状为 [N, 6] |
|||
min1, max1 = selected_bboxes[:, :3], selected_bboxes[:, 3:] # 形状为 [N, 3] |
|||
min2, max2 = child_bboxes[:, :3], child_bboxes[:, 3:] # 形状为 [8, 3] |
|||
|
|||
# 确保广播机制正常工作 |
|||
intersect_mask = torch.all( |
|||
(max1.unsqueeze(0) >= min2.unsqueeze(1)) & # 形状为 [8, N, 3] |
|||
(max2.unsqueeze(1) >= min1.unsqueeze(0)), # 形状为 [8, N, 3] |
|||
dim=-1 |
|||
) # 最终形状为 [8, N] |
|||
|
|||
return intersect_mask |
|||
|
|||
# 测试程序 |
|||
if __name__ == "__main__": |
|||
# 构造输入数据 |
|||
surf_bboxes = torch.tensor([ |
|||
[0, 0, 0, 1, 1, 1], # 立方体 1 |
|||
[0.5, 0.5, 0.5, 1.5, 1.5, 1.5], # 立方体 2 |
|||
[2, 2, 2, 3, 3, 3] # 立方体 3 |
|||
]) # [B=3, 6] |
|||
|
|||
indices = torch.tensor([0, 1]) # 选择前两个立方体 |
|||
|
|||
# 假设父边界框为 [0, 0, 0, 2, 2, 2],生成其八个子边界框 |
|||
parent_bbox = torch.tensor([0, 0, 0, 2, 2, 2]) |
|||
center = (parent_bbox[:3] + parent_bbox[3:]) / 2 |
|||
child_bboxes = torch.tensor([ |
|||
[parent_bbox[0], parent_bbox[1], parent_bbox[2], center[0], center[1], center[2]], # 左下前 |
|||
[center[0], parent_bbox[1], parent_bbox[2], parent_bbox[3], center[1], center[2]], # 右下前 |
|||
[parent_bbox[0], center[1], parent_bbox[2], center[0], parent_bbox[4], center[2]], # 左上前 |
|||
[center[0], center[1], parent_bbox[2], parent_bbox[3], parent_bbox[4], center[2]], # 右上前 |
|||
[parent_bbox[0], parent_bbox[1], center[2], center[0], center[1], parent_bbox[5]], # 左下后 |
|||
[center[0], parent_bbox[1], center[2], parent_bbox[3], center[1], parent_bbox[5]], # 右下后 |
|||
[parent_bbox[0], center[1], center[2], center[0], parent_bbox[4], parent_bbox[5]], # 左上后 |
|||
[center[0], center[1], center[2], parent_bbox[3], parent_bbox[4], parent_bbox[5]] # 右上后 |
|||
]) # [8, 6] |
|||
|
|||
# 调用函数 |
|||
intersect_mask = bbox_intersect(surf_bboxes, indices, child_bboxes) |
|||
|
|||
# 输出结果 |
|||
print("Intersect Mask:") |
|||
print(intersect_mask) |
|||
|
|||
# 将布尔掩码转换为索引列表 |
|||
child_indices = [] |
|||
for i in range(8): # 遍历每个子节点 |
|||
intersecting_faces = indices[intersect_mask[i]] # 获取当前子节点的相交面片索引 |
|||
child_indices.append(intersecting_faces) |
|||
|
|||
# 打印每个子节点对应的相交索引 |
|||
print("\nChild Indices:") |
|||
for i, indices in enumerate(child_indices): |
|||
print(f"Child {i}: {indices}") |
|||
Loading…
Reference in new issue