import torch
import torch.nn as nn
import numpy as np
from typing import Tuple, List, Union

from brep2sdf.utils.logger import logger

class Sine(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        # See the SIREN paper, Sec. 3.2 (final paragraph) and supplement
        # Sec. 1.5, for a discussion of the frequency factor 30.
        return torch.sin(30 * input)
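

# A tiny usage sketch (illustrative only, not part of the original module):
# Sine is a frequency-scaled sin, so inputs in [-1, 1] sweep the argument
# across [-30, 30], i.e. roughly 30 / pi ~ 9.5 full periods.
def _demo_sine() -> None:
    act = Sine()
    x = torch.linspace(-1.0, 1.0, steps=5)
    assert torch.allclose(act(x), torch.sin(30 * x))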


class Decoder(nn.Module):
    def __init__(
        self,
        d_in: int,
        dims_sdf: List[int],
        skip_in: Tuple[int, ...] = (),
        flag_convex: bool = True,
        geometric_init: bool = True,
        radius_init: float = 0.5,
        beta: float = 100,
    ) -> None:
        super().__init__()

        self.flag_convex = flag_convex
        self.skip_in = skip_in

        dims_sdf = [d_in] + dims_sdf + [1]  # [self.n_branch]
        self.sdf_layers = len(dims_sdf)

        # Store the SDF layers in a ModuleList.
        self.sdf_modules = nn.ModuleList()
        for layer in range(0, len(dims_sdf) - 1):
            if layer + 1 in skip_in:
                # Reserve room for the d_in input features that get
                # concatenated back in at the skip connection.
                out_dim = dims_sdf[layer + 1] - d_in
            else:
                out_dim = dims_sdf[layer + 1]
            lin = nn.Linear(dims_sdf[layer], out_dim)

            if geometric_init:
                if layer == self.sdf_layers - 2:
                    # Geometric initialization of the last layer, so the
                    # network starts out approximating the SDF of a sphere
                    # of radius `radius_init`.
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims_sdf[layer]), std=0.00001)
                    torch.nn.init.constant_(lin.bias, -radius_init)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
            self.sdf_modules.append(lin)

        # Layer normalization for each hidden dimension
        # (registered here; not currently applied in forward).
        self.norm_layers = nn.ModuleList()
        for dim in dims_sdf[1:-1]:
            self.norm_layers.append(nn.LayerNorm(dim))

        if geometric_init:
            if beta > 0:
                self.activation = nn.Softplus(beta=beta)
            else:
                # vanilla ReLU
                self.activation = nn.ReLU()
        else:
            # SIREN
            self.activation = Sine()
        self.final_activation = nn.Tanh()
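
    # Construction sketch (hypothetical sizes): Decoder(d_in=64,
    # dims_sdf=[256, 256, 256], skip_in=(2,)) builds four linear layers
    # [64->256, 256->192, 256->256, 256->1]; the 192 leaves room for the
    # 64 input features concatenated back in before layer 2.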

    def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
        '''
        :param feature_matrix: feature matrix of shape (B, P, D)
            B: batch size
            P: number of patch volumes
            D: feature dimension
        :return:
            f_i: per-patch SDF values, shape (B, P)
        '''
        B, P, D = feature_matrix.shape

        # Flatten to (B*P, D).
        x = feature_matrix.view(-1, D)
        inputs = x  # kept for the skip connections

        # Iterate over sdf_modules with their layer indices.
        for layer, lin in enumerate(self.sdf_modules):
            if layer in self.skip_in:
                # Skip connection: re-inject the input features (matching the
                # `- d_in` reservation in __init__) and rescale by sqrt(2) to
                # keep the activation variance roughly constant.
                x = torch.cat([x, inputs], -1) / torch.sqrt(torch.tensor(2.0))
            # logger.print_tensor_stats(f"layer-{layer}>x", x)
            x = lin(x)
            if layer < self.sdf_layers - 2:
                x = self.activation(x)
        output_value = x  # all f_i

        # Restore shape (B, P).
        f_i = output_value.view(B, P)

        return f_i
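
    # Shape trace for forward (illustrative sizes): a (B=2, P=7, D=64) input
    # is flattened to (14, 64), mapped by the MLP to (14, 1), and reshaped
    # into per-patch SDF values of shape (2, 7).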

    @torch.jit.export
    def forward_training_volumes(self, feature_matrix: torch.Tensor) -> torch.Tensor:
        '''
        :param feature_matrix: feature matrix of shape (S, D)
            S: number of samples
            D: feature dimension
        :return:
            f: per-patch SDF values, shape (S,)
        '''
        # The input already has shape (S, D), so it can be used directly.
        x = feature_matrix
        inputs = x  # kept for the skip connections

        for layer, lin in enumerate(self.sdf_modules):
            if layer in self.skip_in:
                # Skip connection, as in forward.
                x = torch.cat([x, inputs], -1) / torch.sqrt(torch.tensor(2.0))
            # logger.print_tensor_stats(f"layer-{layer}>x", x)
            x = lin(x)
            if layer < self.sdf_layers - 2:
                x = self.activation(x)

        output_value = x  # all f values
        # Squeeze the trailing dimension to get shape (S,).
        f = output_value.squeeze(-1)

        return f
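

# A minimal smoke-test sketch (hypothetical sizes, not part of the original
# module): builds a small Decoder and runs both forward paths to check shapes.
def _demo_decoder() -> None:
    decoder = Decoder(d_in=64, dims_sdf=[256, 256, 256], skip_in=(2,))
    features = torch.randn(2, 7, 64)               # (B=2, P=7, D=64)
    assert decoder(features).shape == (2, 7)       # per-patch SDFs
    samples = torch.randn(100, 64)                 # (S=100, D=64)
    assert decoder.forward_training_volumes(samples).shape == (100,)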
"""
|
|
# 一个基础情形: 输入 fi 形状[P] 和 csg tree,凹凸组合输出h
|
|
#注意考虑如何批量处理 (B, P) 和 [csg tree]
|
|
class CSGCombiner:
|
|
def __init__(self, flag_convex: bool = True, rho: float = 0.05):
|
|
self.flag_convex = flag_convex
|
|
self.rho = rho
|
|
|
|
def forward(self, f_i: torch.Tensor, csg_tree
|
|
) -> torch.Tensor:
|
|
'''
|
|
:param f_i: 形状为 (B, P) 的各patch SDF值
|
|
:param csg_tree: CSG树结构
|
|
:return: 组合后的整体SDF (B,)
|
|
'''
|
|
logger.info("\n".join(f"第{i}个csg: {t}" for i,t in enumerate(csg_tree)))
|
|
B = f_i.shape[0]
|
|
results = []
|
|
for i in range(B):
|
|
# 处理每个样本的CSG组合
|
|
h = self.nested_cvx_output_soft_blend(
|
|
f_i[i].unsqueeze(0),
|
|
csg_tree,
|
|
self.flag_convex
|
|
)
|
|
results.append(h)
|
|
|
|
return torch.cat(results, dim=0).squeeze(1) # 从(B,1)变为(B,)
|
|
|
|

    def nested_cvx_output_soft_blend(
        self,
        value_matrix: torch.Tensor,
        list_operation: List[Union[int, List]],
        cvx_flag: bool = True
    ) -> torch.Tensor:
        # Collect the leaf indices at this level of the tree.
        list_value = []
        for v in list_operation:
            if not isinstance(v, list):
                list_value.append(v)

        # Selection matrix picking the leaf columns out of value_matrix.
        op_mat = torch.zeros(value_matrix.shape[1], len(list_value),
                             device=value_matrix.device)
        for i in range(len(list_value)):
            op_mat[list_value[i]][i] = 1.0

        mat_mul = torch.matmul(value_matrix, op_mat)
        if len(list_operation) == len(list_value):
            # No nested sub-trees: blend the leaves directly.
            return self.max_soft_blend(mat_mul, self.rho) if cvx_flag \
                else self.min_soft_blend(mat_mul, self.rho)

        # Recurse into the sub-trees, flipping convex/concave at each level.
        list_output = [mat_mul]
        for v in list_operation:
            if isinstance(v, list):
                list_output.append(
                    self.nested_cvx_output_soft_blend(
                        value_matrix, v, not cvx_flag
                    )
                )

        return self.max_soft_blend(torch.cat(list_output, 1), self.rho) if cvx_flag \
            else self.min_soft_blend(torch.cat(list_output, 1), self.rho)
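
    # Worked example (illustrative): with cvx_flag=True and the tree
    # [0, [1, 2]], the leaf {0} gives mat_mul = [f_0], the sub-tree [1, 2]
    # is evaluated with the flipped flag as min_soft_blend(f_1, f_2), and the
    # final result is max_soft_blend(f_0, min_soft_blend(f_1, f_2)).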

    def min_soft_blend(self, mat, rho):
        # Fold the columns together with a soft (blended) minimum.
        res = mat[:, 0]
        for i in range(1, mat.shape[1]):
            srho = res * res + mat[:, i] * mat[:, i] - rho * rho
            res = res + mat[:, i] - torch.sqrt(res * res + mat[:, i] * mat[:, i] + 1.0 / (8 * rho * rho) * srho * (srho - srho.abs()))
        return res.unsqueeze(1)

    def max_soft_blend(self, mat, rho):
        # Fold the columns together with a soft (blended) maximum.
        res = mat[:, 0]
        for i in range(1, mat.shape[1]):
            srho = res * res + mat[:, i] * mat[:, i] - rho * rho
            res = res + mat[:, i] + torch.sqrt(res * res + mat[:, i] * mat[:, i] + 1.0 / (8 * rho * rho) * srho * (srho - srho.abs()))
        return res.unsqueeze(1)
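
    # Note on the blend (our reading of the formula above): when
    # f1^2 + f2^2 >= rho^2 the srho * (srho - |srho|) term vanishes, leaving
    # the classic R-function pair
    #     f1 + f2 - sqrt(f1^2 + f2^2)   (shares the sign of min(f1, f2))
    #     f1 + f2 + sqrt(f1^2 + f2^2)   (shares the sign of max(f1, f2))
    # while inside that region the extra term rounds the crease, with rho
    # controlling the blend radius.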


def test_csg_combiner():
    # Test data (B=3, P=5).
    f_i = torch.tensor([
        [1.0, 2.0, 3.0, 4.0, 5.0],
        [0.5, 1.5, 2.5, 3.5, 4.5],
        [-1.0, 0.0, 1.0, 2.0, 3.0]
    ])

    # A different CSG tree structure for each sample.
    csg_trees = [
        [0, [1, 2]],    # uses indices 0, 1, 2
        [[0, 1], 3],    # uses indices 0, 1, 3
        [0, 1, [2, 4]]  # uses indices 0, 1, 2, 4
    ]

    # Verify that every index in every tree is valid.
    P = f_i.shape[1]
    for i, tree in enumerate(csg_trees):
        def check_indices(node):
            if isinstance(node, list):
                for n in node:
                    check_indices(n)
            else:
                assert node < P, f"tree of sample {i} contains invalid index {node}, P={P}"
        check_indices(tree)

    print("Input SDF values:")
    print(f_i)
    print("\nCSG Trees:")
    for i, tree in enumerate(csg_trees):
        print(f"Sample {i}: {tree}")

    # Test the convex combination.
    print("\nTesting convex combination:")
    combiner_convex = CSGCombiner(flag_convex=True)
    h_convex = combiner_convex.forward(f_i, csg_trees)
    print("Results:", h_convex)

    # Test the concave combination.
    print("\nTesting concave combination:")
    combiner_concave = CSGCombiner(flag_convex=False)
    h_concave = combiner_concave.forward(f_i, csg_trees)
    print("Results:", h_concave)

    # Test soft blends with different rho values.
    print("\nTesting soft blends:")
    for rho in [0.01, 0.1, 0.5]:
        combiner_soft = CSGCombiner(flag_convex=True, rho=rho)
        h_soft = combiner_soft.forward(f_i, csg_trees)
        print(f"rho={rho}:", h_soft)


if __name__ == "__main__":
    test_csg_combiner()
"""