PyTorch实现图像聚类方法

2021-01-29 16:22:59
图像自编码的步骤如下:
准备输入图像(左上角)
将图像输入编码器,由具有标准CNN和ReLU激活的卷积层(绿色)和最大池层(紫色)组成
得到一个低维的编码
将编码输入译码器,它由转置的卷积层(带归一化和ReLU激活)(浅绿色)和解池化层(浅紫色)加上一个没有归一化或激活的最终卷积层(黄色)
获得与输入尺寸相同的输出图像。
是时候把这个设计变成代码了。
我从创建一个编码器模块开始。第一行,包括初始化方法,如下所示:
import torch
from torch import nn
from torchvision import models
class EncoderVGG(nn.Module):
'''
基于vgg16体系结构的图像编码器,具有batch normalization。
Args:
预训练的params (bool,可选):是否应该用预训练的vGG参数填充网络,默认值为True
'''
channels_in = 3
channels_code = 512
def __init__(self, pretrained_params=True):
super(EncoderVGG, self).__init__()
vgg = models.vgg16_bn(pretrained=pretrained_params)
del vgg.classifier
del vgg.avgpool
self.encoder = self._encodify_(vgg)
编码器的结构与VGG-16卷积网络的特征提取层结构相同。因此,PyTorch库中很容易找到该部分—PyTorch models.vgg16_bn,请参阅代码片段中的第19行。
与VGG的规范应用程序不同,编码不会被输入到分类层中。最后两层vgg.classifier以及vgg.avgpool被丢弃。
编码器的层需要一次调整。在解码器的解池层中,编码器的最大池层中的池索引必须可用,在前面的图像中虚线箭头表示。VGG -16的模板版本不生成这些索引。然而,池化层可以重新初始化。这就是EncoderVGG模块的_encodify方法完成的工作。
def _encodify_(self, encoder):
'''
基于VGG模板的架构创建编码器模块列表。在编码器-解码器体系结构中,解码器中的解池操作需要来自编码器中相应池操作的池索引。在VGG模板中,这些索引不返回。因此需要使用此方法扩展池操作。
参数:
编码器:模板VGG模型
返回:
模块:定义与VGG模型对应的编码器的模块列表
'''
modules = nn.ModuleList()
for module in encoder.features:
if isinstance(module, nn.MaxPool2d):
module_add = nn.MaxPool2d(kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
return_indices=True)
modules.append(module_add)
else:
modules.append(module)
return modules
因为这是一个PyTorch模块(nn.Module),通过EncoderVGG实例实现小批量图像数据的前向传播需要一个forward方法:
def forward(self, x):
'''将图像输入encoder
Args:
x (Tensor): 图片tensor
Returns:
x_code (Tensor): 编码 tensor
pool_indices (list): 池索引张量
'''
pool_indices = []
x_current = x
for module_encode in self.encoder:
output = module_encode(x_current)
# 如果模块是池,有两个输出,第二个是池索引
if isinstance(output, tuple) and len(output) == 2:
x_current = output[0]
pool_indices.append(output[1])
else:
x_current = output
return x_current, pool_indices
该方法按顺序执行编码器中的每个层,并在创建池索引时收集它们。在执行编码器模块之后,代码与池索引的有序集合一起返回。
译码器模块的初始化:
class DecoderVGG(nn.Module):
'''译码器的代码基于vgg16体系结构与batch normalization。
Args:
encoder: ' EncoderVGG '的编码器实例,它将被转换成一个解码器
'''
channels_in = EncoderVGG.channels_code
channels_out = 3
def __init__(self, encoder):
super(DecoderVGG, self).__init__()
self.decoder = self._invert_(encoder)

def _invert_(self, encoder):
'''将编码器反转,以将译码器创建为编码器的镜像
译码器由两种主要类型组成:二维转置卷积和二维解池,2D卷积之后是批处理归一化和激活。
译码器是反向的,编码器中的卷积变成了转置卷积加上归一化和激活,编码器中的maxpooling变成了unpooling。
Args:
encoder (ModuleList): 编码器
Returns:
decoder (ModuleList): 通过编码器的“反转”获得的译码器
'''
modules_transpose = []
for module in reversed(encoder):
if isinstance(module, nn.Conv2d):
kwargs = {'in_channels' : module.out_channels, 'out_channels' : module.in_channels,
'kernel_size' : module.kernel_size, 'stride' : module.stride,
'padding' : module.padding}
module_transpose = nn.ConvTranspose2d(**kwargs)
module_norm = nn.BatchNorm2d(module.in_channels)
module_act = nn.ReLU(inplace=True)
modules_transpose += [module_transpose, module_norm, module_act]
elif isinstance(module, nn.MaxPool2d):
kwargs = {'kernel_size' : module.kernel_size, 'stride' : module.stride,
'padding' : module.padding}
module_transpose = nn.MaxUnpool2d(**kwargs)
modules_transpose += [module_transpose]
# 放弃最后的归一化和激活函数
modules_transpose = modules_transpose[:-2]
return nn.ModuleList(modules_transpose)
_invert_方法反向遍历编码器的各个层。
返回顶部
顶部