Browse Source

fix: 由于设置的采样点太小,取消Encoder1D的池化操作

main
mckay 4 months ago
parent
commit
c5bbc29f6a
  1. 39
      brep2sdf/networks/encoder.py

39
brep2sdf/networks/encoder.py

@ -77,6 +77,45 @@ class UNetMidBlock1D(nn.Module):
return x
class Encoder1D(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 256,
block_out_channels: Tuple[int] = (64, 128, 256),
layers_per_block: int = 2, # 添加这个参数但不使用
):
super().__init__()
self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 1) # 使用1x1卷积
self.blocks = nn.ModuleList([])
in_ch = block_out_channels[0]
for out_ch in block_out_channels:
block = nn.Sequential(
nn.Conv1d(in_ch, out_ch, 1), # 1x1卷积
nn.BatchNorm1d(out_ch),
nn.ReLU(inplace=True)
)
self.blocks.append(block)
in_ch = out_ch
# 添加中间块和输出卷积
self.mid_block = nn.Sequential(
nn.Conv1d(block_out_channels[-1], block_out_channels[-1], 1),
nn.BatchNorm1d(block_out_channels[-1]),
nn.ReLU(inplace=True)
)
self.conv_out = nn.Conv1d(block_out_channels[-1], out_channels, 1)
def forward(self, x):
x = self.conv_in(x)
for block in self.blocks: # 使用self.blocks而不是self.down_blocks
x = block(x)
x = self.mid_block(x)
x = self.conv_out(x)
return x
class Encoder1D_(nn.Module):
"""一维编码器"""
def __init__(
self,

Loading…
Cancel
Save