11 changed files with 780 additions and 310 deletions
@ -1,42 +1,219 @@ |
|||||
import torch.nn as nn |
import torch.nn as nn |
||||
import torch.nn.functional as F |
import torch |
||||
from torch import Tensor |
import numpy as np |
||||
|
from typing import Tuple, List, Union |
||||
|
from brep2sdf.utils.logger import logger |
||||
|
|
||||
|
class Sine(nn.Module):
    """Sine activation for SIREN-style networks.

    Applies ``sin(30 * x)`` element-wise. The factor 30 follows the SIREN
    paper (sec. 3.2, final paragraph, and supplement sec. 1.5).
    """

    def __init__(self) -> None:
        # Bug fix: the constructor was misnamed `__init` (single trailing
        # underscore pair missing), so it was dead code and never executed.
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``sin(30 * input)``, element-wise."""
        return torch.sin(30 * input)
||||
class Decoder(nn.Module):
    """MLP mapping per-patch feature vectors to scalar SDF values.

    IGR-style architecture: a stack of linear layers with optional skip
    connections (the raw input is concatenated back in at the layers listed
    in ``skip_in``) and geometric initialization so that, at init, the
    network approximates the SDF of a sphere of radius ``radius_init``.
    """

    def __init__(
        self,
        d_in: int,
        dims_sdf: List[int],
        skip_in: Tuple[int, ...] = (),
        flag_convex: bool = True,
        geometric_init: bool = True,
        radius_init: float = 1,
        beta: float = 100,
    ) -> None:
        """
        Args:
            d_in: dimensionality of each input feature vector.
            dims_sdf: hidden-layer widths; the input width (d_in) and the
                scalar output layer (1) are appended automatically.
            skip_in: indices of layers whose input is the previous
                activation concatenated with the original network input.
            flag_convex: stored for callers; not used inside the decoder.
            geometric_init: if True, use geometric (sphere) initialization
                with Softplus/ReLU activations; otherwise use SIREN's Sine.
            radius_init: radius of the sphere approximated at init.
            beta: Softplus beta; beta <= 0 falls back to plain ReLU.
        """
        super().__init__()

        self.flag_convex = flag_convex
        self.skip_in = skip_in

        dims_sdf = [d_in] + dims_sdf + [1]
        self.sdf_layers = len(dims_sdf)

        for layer in range(0, len(dims_sdf) - 1):
            if layer + 1 in skip_in:
                # The next layer receives the raw input concatenated to its
                # activation, so this layer outputs d_in fewer features.
                out_dim = dims_sdf[layer + 1] - d_in
            else:
                out_dim = dims_sdf[layer + 1]
            lin = nn.Linear(dims_sdf[layer], out_dim)

            if geometric_init:
                if layer == self.sdf_layers - 2:
                    # Final layer: initialize so the output approximates
                    # |x| - radius_init (the SDF of a sphere).
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims_sdf[layer]), std=0.00001)
                    torch.nn.init.constant_(lin.bias, -radius_init)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
            setattr(self, "sdf_" + str(layer), lin)

        if geometric_init:
            if beta > 0:
                self.activation = nn.Softplus(beta=beta)
            else:
                # vanilla relu
                self.activation = nn.ReLU()
        else:
            # siren
            self.activation = Sine()
        # NOTE(review): final_activation is defined but never applied in
        # forward() — confirm whether a ReLU on the output is intended.
        self.final_activation = nn.ReLU()

    def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
        """Compute per-patch SDF values.

        Args:
            feature_matrix: tensor of shape (B, P, D) where B is the batch
                size, P the number of patch volumes, D the feature dim.

        Returns:
            f_i: per-patch SDF values, shape (B, P).
        """
        B, P, D = feature_matrix.shape

        # Flatten to (B*P, D) so all patches share one MLP pass.
        x = feature_matrix.view(-1, D)
        inp = x  # keep the flattened input for skip connections

        for layer in range(0, self.sdf_layers - 1):
            lin = getattr(self, "sdf_" + str(layer))
            if layer in self.skip_in:
                # Bug fix: skip layers must concatenate the ORIGINAL input
                # (layer widths were sized for d_in extra features in
                # __init__); concatenating x with itself produced a width
                # mismatch whenever skip_in was non-empty.
                x = torch.cat([x, inp], -1) / np.sqrt(2)
            x = lin(x)
            if layer < self.sdf_layers - 2:
                x = self.activation(x)

        # Restore the (B, P) layout; the last layer outputs one scalar each.
        return x.view(B, P)
||||
|
|
||||
|
|
||||
|
# Base case: given f_i of shape [P] and a CSG tree, combine the convex and
# concave parts into an overall value h.
# NOTE: consider how to batch-process (B, P) together with a list of CSG trees.
|
class CSGCombiner:
    """Combine per-patch SDF values into a single SDF via a CSG tree.

    The tree is a nested list of patch indices; each nesting level flips
    between a convex (soft-max) and a concave (soft-min) blend, with the
    blend radius controlled by ``rho``.
    """

    def __init__(self, flag_convex: bool = True, rho: float = 0.05):
        # flag_convex: whether the tree root blends convexly (soft max).
        # rho: blend radius of the soft min/max R-functions.
        self.flag_convex = flag_convex
        self.rho = rho

    def forward(self, f_i: torch.Tensor, csg_tree) -> torch.Tensor:
        """Blend per-patch SDFs for each sample in the batch.

        Args:
            f_i: (B, P) per-patch SDF values.
            csg_tree: CSG tree structure.

        Returns:
            Combined SDF values, shape (B,).
        """
        logger.info("\n".join(f"第{i}个csg: {t}" for i,t in enumerate(csg_tree)))
        batch = f_i.shape[0]
        blended = []
        for b in range(batch):
            # Each sample is combined independently against the CSG tree.
            blended.append(
                self.nested_cvx_output_soft_blend(
                    f_i[b].unsqueeze(0),
                    csg_tree,
                    self.flag_convex,
                )
            )
        return torch.cat(blended, dim=0).squeeze(1)  # (B, 1) -> (B,)

    def nested_cvx_output_soft_blend(
        self,
        value_matrix: torch.Tensor,
        list_operation: List[Union[int, List]],
        cvx_flag: bool = True
    ) -> torch.Tensor:
        """Recursively evaluate one CSG tree level with soft blends."""
        # Leaf patch indices at this level (non-list entries).
        leaves = [node for node in list_operation if not isinstance(node, list)]

        # Selection matrix: column j picks patch index leaves[j].
        select = torch.zeros(value_matrix.shape[1], len(leaves),
                             device=value_matrix.device)
        for col, idx in enumerate(leaves):
            select[idx][col] = 1.0

        picked = torch.matmul(value_matrix, select)
        blend = self.max_soft_blend if cvx_flag else self.min_soft_blend

        if len(list_operation) == len(leaves):
            # No nested sub-trees: blend the leaf values directly.
            return blend(picked, self.rho)

        # Recurse into sub-trees with the opposite convexity, then blend
        # the leaf values together with the sub-tree results.
        parts = [picked]
        for node in list_operation:
            if isinstance(node, list):
                parts.append(
                    self.nested_cvx_output_soft_blend(
                        value_matrix, node, not cvx_flag
                    )
                )
        return blend(torch.cat(parts, 1), self.rho)

    def min_soft_blend(self, mat, rho):
        """Soft (R-function) minimum across the columns of ``mat``; (B, 1)."""
        acc = mat[:, 0]
        for col in range(1, mat.shape[1]):
            cur = mat[:, col]
            srho = acc * acc + cur * cur - rho * rho
            # The blend correction only activates near the surface (srho < 0).
            correction = 1.0 / (8 * rho * rho) * srho * (srho - srho.abs())
            acc = acc + cur - torch.sqrt(acc * acc + cur * cur + correction)
        return acc.unsqueeze(1)

    def max_soft_blend(self, mat, rho):
        """Soft (R-function) maximum across the columns of ``mat``; (B, 1)."""
        acc = mat[:, 0]
        for col in range(1, mat.shape[1]):
            cur = mat[:, col]
            srho = acc * acc + cur * cur - rho * rho
            # The blend correction only activates near the surface (srho < 0).
            correction = 1.0 / (8 * rho * rho) * srho * (srho - srho.abs())
            acc = acc + cur + torch.sqrt(acc * acc + cur * cur + correction)
        return acc.unsqueeze(1)
||||
|
|
||||
|
|
||||
|
def test_csg_combiner():
    """Smoke-test CSGCombiner on a small batch with varied CSG trees."""
    # Test data (B=3, P=5).
    f_i = torch.tensor([
        [1.0, 2.0, 3.0, 4.0, 5.0],
        [0.5, 1.5, 2.5, 3.5, 4.5],
        [-1.0, 0.0, 1.0, 2.0, 3.0],
    ])

    # A different CSG tree structure per sample.
    csg_trees = [
        [0, [1, 2]],     # uses indices 0, 1, 2
        [[0, 1], 3],     # uses indices 0, 1, 3
        [0, 1, [2, 4]],  # uses indices 0, 1, 2, 4
    ]

    # Validate that every index referenced by a tree is a valid patch index.
    P = f_i.shape[1]
    for i, tree in enumerate(csg_trees):
        def check_indices(node):
            if isinstance(node, list):
                for n in node:
                    check_indices(n)
            else:
                assert node < P, f"样本{i}的树包含无效索引{node},P={P}"
        check_indices(tree)

    print("Input SDF values:")
    print(f_i)
    print("\nCSG Trees:")
    for i, tree in enumerate(csg_trees):
        print(f"Sample {i}: {tree}")

    # Convex combination.
    print("\nTesting convex combination:")
    cvx_combiner = CSGCombiner(flag_convex=True)
    res_convex = cvx_combiner.forward(f_i, csg_trees)
    print("Results:", res_convex)

    # Concave combination.
    print("\nTesting concave combination:")
    ccv_combiner = CSGCombiner(flag_convex=False)
    res_concave = ccv_combiner.forward(f_i, csg_trees)
    print("Results:", res_concave)

    # Soft blends at several rho values.
    print("\nTesting soft blends:")
    for rho in [0.01, 0.1, 0.5]:
        soft_combiner = CSGCombiner(flag_convex=True, rho=rho)
        res_soft = soft_combiner.forward(f_i, csg_trees)
        print(f"rho={rho}:", res_soft)


if __name__ == "__main__":
    test_csg_combiner()
||||
@ -1,4 +1,69 @@ |
|||||
import torch |
import torch |
||||
|
|
||||
model = torch.jit.load("/home/wch/brep2sdf/data/output_data/00000054.pt") |
from typing import List, Tuple |
||||
print(model) |
|
||||
|
def bbox_intersect(surf_bboxes: torch.Tensor, indices: torch.Tensor, child_bboxes: torch.Tensor) -> torch.Tensor:
    '''
    Axis-aligned bbox overlap test between selected boxes and octree children.

    args:
        surf_bboxes: [B, 6] boxes, each stored as (min_xyz, max_xyz).
        indices: [N] (N <= B) index tensor selecting rows of surf_bboxes.
        child_bboxes: [8, 6] the eight children of a subdivided parent box.
    return:
        intersect_mask: [8, N] boolean mask; entry [c, n] is True when
        child c and selected box n overlap (touching faces count).
    '''
    # Gather the selected boxes and split mins from maxes.
    chosen = surf_bboxes[indices]                          # [N, 6]
    lo_a, hi_a = chosen[:, :3], chosen[:, 3:]              # [N, 3]
    lo_b, hi_b = child_bboxes[:, :3], child_bboxes[:, 3:]  # [8, 3]

    # Two AABBs overlap iff, on every axis, each box's max reaches the
    # other box's min. Broadcasting produces an [8, N, 3] comparison.
    overlap_per_axis = (hi_a.unsqueeze(0) >= lo_b.unsqueeze(1)) & \
                       (hi_b.unsqueeze(1) >= lo_a.unsqueeze(0))

    return torch.all(overlap_per_axis, dim=-1)             # [8, N]
||||
|
|
||||
|
# Test driver.
if __name__ == "__main__":
    # Build the input data: three axis-aligned boxes (min_xyz, max_xyz).
    surf_bboxes = torch.tensor([
        [0, 0, 0, 1, 1, 1],              # cube 1
        [0.5, 0.5, 0.5, 1.5, 1.5, 1.5],  # cube 2
        [2, 2, 2, 3, 3, 3],              # cube 3
    ])  # [B=3, 6]

    indices = torch.tensor([0, 1])  # select the first two cubes

    # Parent box [0, 0, 0, 2, 2, 2]; generate its eight octant children.
    parent_bbox = torch.tensor([0, 0, 0, 2, 2, 2])
    center = (parent_bbox[:3] + parent_bbox[3:]) / 2
    child_bboxes = torch.tensor([
        [parent_bbox[0], parent_bbox[1], parent_bbox[2], center[0], center[1], center[2]],  # lower-left front
        [center[0], parent_bbox[1], parent_bbox[2], parent_bbox[3], center[1], center[2]],  # lower-right front
        [parent_bbox[0], center[1], parent_bbox[2], center[0], parent_bbox[4], center[2]],  # upper-left front
        [center[0], center[1], parent_bbox[2], parent_bbox[3], parent_bbox[4], center[2]],  # upper-right front
        [parent_bbox[0], parent_bbox[1], center[2], center[0], center[1], parent_bbox[5]],  # lower-left back
        [center[0], parent_bbox[1], center[2], parent_bbox[3], center[1], parent_bbox[5]],  # lower-right back
        [parent_bbox[0], center[1], center[2], center[0], parent_bbox[4], parent_bbox[5]],  # upper-left back
        [center[0], center[1], center[2], parent_bbox[3], parent_bbox[4], parent_bbox[5]],  # upper-right back
    ])  # [8, 6]

    # Run the intersection test.
    intersect_mask = bbox_intersect(surf_bboxes, indices, child_bboxes)

    # Print the raw boolean mask.
    print("Intersect Mask:")
    print(intersect_mask)

    # Convert the boolean mask into per-child index lists.
    child_indices = []
    for i in range(8):  # one entry per octant child
        intersecting_faces = indices[intersect_mask[i]]  # surface indices hitting child i
        child_indices.append(intersecting_faces)

    # Print each child's intersecting surface indices.
    # Bug fix: the loop variable was named `indices`, shadowing the
    # selection tensor used above — renamed to keep the scope unambiguous.
    print("\nChild Indices:")
    for i, child_idx in enumerate(child_indices):
        print(f"Child {i}: {child_idx}")
||||
Loading…
Reference in new issue