From c5bbc29f6afe5409313b6f1af94d6a4fc606c2b1 Mon Sep 17 00:00:00 2001 From: mckay Date: Fri, 22 Nov 2024 00:51:19 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E7=94=B1=E4=BA=8E=E8=AE=BE=E7=BD=AE?= =?UTF-8?q?=E7=9A=84=E9=87=87=E6=A0=B7=E7=82=B9=E5=A4=AA=E5=B0=8F=EF=BC=8C?= =?UTF-8?q?=E5=8F=96=E6=B6=88Encoder1D=E7=9A=84=E6=B1=A0=E5=8C=96=E6=93=8D?= =?UTF-8?q?=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- brep2sdf/networks/encoder.py | 39 ++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/brep2sdf/networks/encoder.py b/brep2sdf/networks/encoder.py index 2f336a3..38eb340 100644 --- a/brep2sdf/networks/encoder.py +++ b/brep2sdf/networks/encoder.py @@ -77,6 +77,45 @@ class UNetMidBlock1D(nn.Module): return x class Encoder1D(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 256, + block_out_channels: Tuple[int] = (64, 128, 256), + layers_per_block: int = 2, # 添加这个参数但不使用 + ): + super().__init__() + self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 1) # 使用1x1卷积 + + self.blocks = nn.ModuleList([]) + in_ch = block_out_channels[0] + for out_ch in block_out_channels: + block = nn.Sequential( + nn.Conv1d(in_ch, out_ch, 1), # 1x1卷积 + nn.BatchNorm1d(out_ch), + nn.ReLU(inplace=True) + ) + self.blocks.append(block) + in_ch = out_ch + + # 添加中间块和输出卷积 + self.mid_block = nn.Sequential( + nn.Conv1d(block_out_channels[-1], block_out_channels[-1], 1), + nn.BatchNorm1d(block_out_channels[-1]), + nn.ReLU(inplace=True) + ) + + self.conv_out = nn.Conv1d(block_out_channels[-1], out_channels, 1) + + def forward(self, x): + x = self.conv_in(x) + for block in self.blocks: # 使用self.blocks而不是self.down_blocks + x = block(x) + x = self.mid_block(x) + x = self.conv_out(x) + return x + +class Encoder1D_(nn.Module): """一维编码器""" def __init__( self,