diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 2f336a3..38eb340 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -77,6 +77,45 @@ class UNetMidBlock1D(nn.Module): return x class Encoder1D(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 256, + block_out_channels: Tuple[int] = (64, 128, 256), + layers_per_block: int = 2, # 添加这个参数但不使用 + ): + super().__init__() + self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 1) # 使用1x1卷积 + + self.blocks = nn.ModuleList([]) + in_ch = block_out_channels[0] + for out_ch in block_out_channels: + block = nn.Sequential( + nn.Conv1d(in_ch, out_ch, 1), # 1x1卷积 + nn.BatchNorm1d(out_ch), + nn.ReLU(inplace=True) + ) + self.blocks.append(block) + in_ch = out_ch + + # 添加中间块和输出卷积 + self.mid_block = nn.Sequential( + nn.Conv1d(block_out_channels[-1], block_out_channels[-1], 1), + nn.BatchNorm1d(block_out_channels[-1]), + nn.ReLU(inplace=True) + ) + + self.conv_out = nn.Conv1d(block_out_channels[-1], out_channels, 1) + + def forward(self, x): + x = self.conv_in(x) + for block in self.blocks: # 使用self.blocks而不是self.down_blocks + x = block(x) + x = self.mid_block(x) + x = self.conv_out(x) + return x + +class Encoder1D_(nn.Module): """一维编码器""" def __init__( self,