You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

255 lines
8.6 KiB

import torch.nn as nn
import torch
import numpy as np
from typing import Tuple, List, Union
from brep2sdf.utils.logger import logger
class Sine(nn.Module):
    """Periodic activation ``sin(w0 * x)`` (SIREN-style).

    Args:
        w0: frequency factor applied to the input before the sine.
            Defaults to 30, the factor the original code hard-coded.
    """

    def __init__(self, w0: float = 30.0) -> None:
        super().__init__()
        self.w0 = w0

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
        return torch.sin(self.w0 * input)
class Decoder(nn.Module):
    """MLP mapping feature vectors to one scalar SDF value each.

    Layer widths are ``d_in -> dims_sdf[0] -> ... -> dims_sdf[-1] -> 1``.
    Every linear layer except the last is followed by ``self.activation``.
    A linear layer whose index is in ``skip_in`` receives the previous
    activation concatenated with the network input, divided by sqrt(2).
    The layer feeding it is narrowed by ``d_in`` so the concatenation has
    exactly the nominal width.

    Args:
        d_in: width of each input feature vector.
        dims_sdf: widths of the hidden layers (list or tuple).
        skip_in: indices of linear layers that receive the skip input.
        flag_convex: stored on the module; not used by this class.
        geometric_init: if True, use the geometric initialisation below and a
            Softplus (``beta > 0``) or ReLU (``beta <= 0``) activation; if
            False, keep PyTorch's default init and use the ``Sine`` activation.
        radius_init: under geometric init, the last layer's bias is set to
            ``-radius_init``.
        beta: Softplus sharpness (geometric init only).
    """

    def __init__(
        self,
        d_in: int,
        dims_sdf: List[int],
        skip_in: Tuple[int, ...] = (),
        flag_convex: bool = True,
        geometric_init: bool = True,
        radius_init: float = 0.5,
        beta: float = 100,
    ) -> None:
        super().__init__()
        self.flag_convex = flag_convex
        self.skip_in = skip_in
        # Full width list: input, hidden layers, scalar output.
        # list() also accepts tuples and never mutates the caller's list.
        dims_sdf = [d_in] + list(dims_sdf) + [1]
        self.sdf_layers = len(dims_sdf)

        # Linear layers of the SDF MLP, stored in a ModuleList.
        self.sdf_modules = nn.ModuleList()
        for layer in range(len(dims_sdf) - 1):
            if layer + 1 in skip_in:
                # The next layer concatenates the d_in-wide input,
                # so this layer must leave room for it.
                out_dim = dims_sdf[layer + 1] - d_in
            else:
                out_dim = dims_sdf[layer + 1]
            lin = nn.Linear(dims_sdf[layer], out_dim)
            if geometric_init:
                if layer == self.sdf_layers - 2:
                    # Last layer: near-constant positive weights, bias -radius_init.
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims_sdf[layer]), std=0.00001)
                    torch.nn.init.constant_(lin.bias, -radius_init)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
            self.sdf_modules.append(lin)

        # NOTE(review): these LayerNorms are never applied in the forward passes.
        # They are kept so existing checkpoints (state_dict keys) still load.
        self.norm_layers = nn.ModuleList(nn.LayerNorm(dim) for dim in dims_sdf[1:-1])

        if geometric_init:
            # beta <= 0 selects a plain ReLU.
            self.activation = nn.Softplus(beta=beta) if beta > 0 else nn.ReLU()
        else:
            # SIREN-style periodic activation.
            self.activation = Sine()
        # NOTE(review): not applied by forward(); kept for attribute compatibility.
        self.final_activation = nn.Tanh()

    def _run_mlp(self, x: torch.Tensor) -> torch.Tensor:
        """Run the linear stack on a 2-D input of shape (N, d_in); returns (N, 1)."""
        inputs = x
        for layer, lin in enumerate(self.sdf_modules):
            if layer in self.skip_in:
                # Re-inject the network input. The preceding layer was narrowed
                # by d_in in __init__ so this concatenation matches lin's width.
                x = torch.cat([x, inputs], -1) / (2.0 ** 0.5)
            x = lin(x)
            # No activation after the final (output) layer.
            if layer < self.sdf_layers - 2:
                x = self.activation(x)
        return x

    def forward(self, feature_matrix: torch.Tensor) -> torch.Tensor:
        """Evaluate the SDF for every patch of every batch element.

        :param feature_matrix: feature tensor of shape (B, P, D)
            B: batch size
            P: number of patch volumes
            D: feature dimension (must equal ``d_in``)
        :return:
            f_i: per-patch SDF values, shape (B, P)
        """
        B, P, D = feature_matrix.shape
        # Flatten to (B*P, D); reshape also handles non-contiguous inputs.
        output_value = self._run_mlp(feature_matrix.reshape(-1, D))
        # Restore (B, P).
        return output_value.reshape(B, P)

    @torch.jit.export
    def forward_training_volumes(self, feature_matrix: torch.Tensor) -> torch.Tensor:
        """Evaluate the SDF on an already-flat feature matrix.

        :param feature_matrix: feature tensor of shape (S, D)
            S: number of samples
            D: feature dimension (must equal ``d_in``)
        :return:
            f: SDF values, shape (S,)
        """
        return self._run_mlp(feature_matrix).squeeze(-1)
"""
# 一个基础情形: 输入 fi 形状[P] 和 csg tree,凹凸组合输出h
#注意考虑如何批量处理 (B, P) 和 [csg tree]
class CSGCombiner:
def __init__(self, flag_convex: bool = True, rho: float = 0.05):
self.flag_convex = flag_convex
self.rho = rho
def forward(self, f_i: torch.Tensor, csg_tree
) -> torch.Tensor:
'''
:param f_i: 形状为 (B, P) 的各patch SDF值
:param csg_tree: CSG树结构
:return: 组合后的整体SDF (B,)
'''
logger.info("\n".join(f"{i}个csg: {t}" for i,t in enumerate(csg_tree)))
B = f_i.shape[0]
results = []
for i in range(B):
# 处理每个样本的CSG组合
h = self.nested_cvx_output_soft_blend(
f_i[i].unsqueeze(0),
csg_tree,
self.flag_convex
)
results.append(h)
return torch.cat(results, dim=0).squeeze(1) # 从(B,1)变为(B,)
def nested_cvx_output_soft_blend(
self,
value_matrix: torch.Tensor,
list_operation: List[Union[int, List]],
cvx_flag: bool = True
) -> torch.Tensor:
list_value = []
for v in list_operation:
if not isinstance(v, list):
list_value.append(v)
op_mat = torch.zeros(value_matrix.shape[1], len(list_value),
device=value_matrix.device)
for i in range(len(list_value)):
op_mat[list_value[i]][i] = 1.0
mat_mul = torch.matmul(value_matrix, op_mat)
if len(list_operation) == len(list_value):
return self.max_soft_blend(mat_mul, self.rho) if cvx_flag \
else self.min_soft_blend(mat_mul, self.rho)
list_output = [mat_mul]
for v in list_operation:
if isinstance(v, list):
list_output.append(
self.nested_cvx_output_soft_blend(
value_matrix, v, not cvx_flag
)
)
return self.max_soft_blend(torch.cat(list_output, 1), self.rho) if cvx_flag \
else self.min_soft_blend(torch.cat(list_output, 1), self.rho)
def min_soft_blend(self, mat, rho):
res = mat[:,0]
for i in range(1, mat.shape[1]):
srho = res * res + mat[:,i] * mat[:,i] - rho * rho
res = res + mat[:,i] - torch.sqrt(res * res + mat[:,i] * mat[:,i] + 1.0/(8 * rho * rho) * srho * (srho - srho.abs()))
return res.unsqueeze(1)
def max_soft_blend(self, mat, rho):
res = mat[:,0]
for i in range(1, mat.shape[1]):
srho = res * res + mat[:,i] * mat[:,i] - rho * rho
res = res + mat[:,i] + torch.sqrt(res * res + mat[:,i] * mat[:,i] + 1.0/(8 * rho * rho) * srho * (srho - srho.abs()))
return res.unsqueeze(1)
def test_csg_combiner():
# 测试数据 (B=3, P=5)
f_i = torch.tensor([
[1.0, 2.0, 3.0, 4.0, 5.0],
[0.5, 1.5, 2.5, 3.5, 4.5],
[-1.0, 0.0, 1.0, 2.0, 3.0]
])
# 每个样本使用不同的CSG树结构
csg_trees = [
[0, [1, 2]], # 使用索引0,1,2
[[0, 1], 3], # 使用索引0,1,3
[0, 1, [2, 4]] # 使用索引0,1,2,4
]
# 验证所有索引都有效
P = f_i.shape[1]
for i, tree in enumerate(csg_trees):
def check_indices(node):
if isinstance(node, list):
for n in node:
check_indices(n)
else:
assert node < P, f"样本{i}的树包含无效索引{node},P={P}"
check_indices(tree)
print("Input SDF values:")
print(f_i)
print("\nCSG Trees:")
for i, tree in enumerate(csg_trees):
print(f"Sample {i}: {tree}")
# 测试凸组合
print("\nTesting convex combination:")
combiner_convex = CSGCombiner(flag_convex=True)
h_convex = combiner_convex.forward(f_i, csg_trees)
print("Results:", h_convex)
# 测试凹组合
print("\nTesting concave combination:")
combiner_concave = CSGCombiner(flag_convex=False)
h_concave = combiner_concave.forward(f_i, csg_trees)
print("Results:", h_concave)
# 测试不同rho值的软混合
print("\nTesting soft blends:")
for rho in [0.01, 0.1, 0.5]:
combiner_soft = CSGCombiner(flag_convex=True, rho=rho)
h_soft = combiner_soft.forward(f_i, csg_trees)
print(f"rho={rho}:", h_soft)
if __name__ == "__main__":
test_csg_combiner()
"""