|
|
@ -77,6 +77,45 @@ class UNetMidBlock1D(nn.Module): |
|
|
|
return x |
|
|
|
|
|
|
|
class Encoder1D(nn.Module): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
in_channels: int = 3, |
|
|
|
out_channels: int = 256, |
|
|
|
block_out_channels: Tuple[int] = (64, 128, 256), |
|
|
|
layers_per_block: int = 2, # 添加这个参数但不使用 |
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.conv_in = nn.Conv1d(in_channels, block_out_channels[0], 1) # 使用1x1卷积 |
|
|
|
|
|
|
|
self.blocks = nn.ModuleList([]) |
|
|
|
in_ch = block_out_channels[0] |
|
|
|
for out_ch in block_out_channels: |
|
|
|
block = nn.Sequential( |
|
|
|
nn.Conv1d(in_ch, out_ch, 1), # 1x1卷积 |
|
|
|
nn.BatchNorm1d(out_ch), |
|
|
|
nn.ReLU(inplace=True) |
|
|
|
) |
|
|
|
self.blocks.append(block) |
|
|
|
in_ch = out_ch |
|
|
|
|
|
|
|
# 添加中间块和输出卷积 |
|
|
|
self.mid_block = nn.Sequential( |
|
|
|
nn.Conv1d(block_out_channels[-1], block_out_channels[-1], 1), |
|
|
|
nn.BatchNorm1d(block_out_channels[-1]), |
|
|
|
nn.ReLU(inplace=True) |
|
|
|
) |
|
|
|
|
|
|
|
self.conv_out = nn.Conv1d(block_out_channels[-1], out_channels, 1) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
x = self.conv_in(x) |
|
|
|
for block in self.blocks: # 使用self.blocks而不是self.down_blocks |
|
|
|
x = block(x) |
|
|
|
x = self.mid_block(x) |
|
|
|
x = self.conv_out(x) |
|
|
|
return x |
|
|
|
|
|
|
|
class Encoder1D_(nn.Module): |
|
|
|
"""一维编码器""" |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|