unet详解-凯发k8官方网
介绍
在图像分割中,机器必须将图像分割成不同的segments,每个segment代表不同的实体。
图像分割示例
正如你在上面看到的,图像如何变成两个部分,一个代表猫,另一个代表背景。图像分割在从自动驾驶汽车到卫星的许多领域都很有用。也许其中最重要的是医学影像。医学图像的微妙之处是相当复杂的。一台能够理解这些细微差别并识别出必要区域的机器,可以对医疗保健产生深远的影响。
卷积神经网络在简单的图像分割问题上取得了不错的效果,但在复杂的图像分割问题上却没有取得任何进展。这就是unet的作用。unet最初是专门为医学图像分割而设计的。该方法取得了良好的效果,并在以后的许多领域得到了应用。在本文中,我们将讨论unet工作的原因和方式
unet背后的直觉
卷积神经网络(cnn)背后的主要思想是学习图像的特征映射,并利用它进行更细致的特征映射。这在分类问题中很有效,因为图像被转换成一个向量,这个向量用于进一步的分类。但是在图像分割中,我们不仅需要将feature map转换成一个向量,还需要从这个向量重建图像。这是一项巨大的任务,因为要将向量转换成图像比反过来更困难。unet的整个理念都围绕着这个问题。
在将图像转换为向量的过程中,我们已经学习了图像的特征映射,为什么不使用相同的映射将其再次转换为图像呢?这就是unet背后的秘诀。用同样的 feature maps,将其用于contraction 来将矢量扩展成segmented image。这将保持图像的结构完整性,这将极大地减少失真。让我们更简单地理解架构。
unet架构
unet架构
该架构看起来像一个'u'。该体系结构由三部分组成:contraction,bottleneck和expansion 部分。contraction部分由许多contraction块组成。每个块接受一个输入,应用两个3x3的卷积层,然后是一个2x2的最大池化。在每个块之后,核或特征映射的数量会加倍,这样体系结构就可以有效地学习复杂的结构。最底层介于contraction层和expansion 层之间。它使用两个3x3 cnn层,然后是2x2 up convolution层。
这种架构的核心在于expansion 部分。与contraction层类似,它也包含几个expansion 块。每个块将输入传递到两个3x3 cnn层,然后是2x2上采样层。此外,卷积层使用的每个块的feature map数量得到一半,以保持对称性。每次输入也被相应的收缩层的 feature maps所附加。这个动作将确保在contracting 图像时学习到的特征将被用于重建图像。expansion 块的数量与contraction块的数量相同。之后,生成的映射通过另一个3x3 cnn层,feature map的数量等于所需的segment的数量。
unet中的损失计算
unet对每个像素使用了一种新颖的损失加权方案,使得分割对象的边缘具有更高的权重。这种损失加权方案帮助u-net模型以不连续的方式分割生物医学图像中的细胞,以便在binary segmentation map中容易识别单个细胞。
首先,在所得图像上应用pixel-wise softmax,然后是交叉熵损失函数。所以我们将每个像素分类为一个类。我们的想法是,即使在分割中,每个像素都必须存在于某个类别中,我们只需要确保它们可以。因此,我们只是将分段问题转换为多类分类问题,与传统的损失函数相比,它表现得非常好。
unet实现的python代码
python代码如下:
import torchfrom torch import nnimport torch.nn.functional as fimport torch.optim as optimclass unet(nn.module): def contracting_block(self, in_channels, out_channels, kernel_size=3): block = torch.nn.sequential( torch.nn.conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels), torch.nn.relu(), torch.nn.batchnorm2d(out_channels), torch.nn.conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels), torch.nn.relu(), torch.nn.batchnorm2d(out_channels), ) return block def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3): block = torch.nn.sequential( torch.nn.conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel), torch.nn.relu(), torch.nn.batchnorm2d(mid_channel), torch.nn.conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel), torch.nn.relu(), torch.nn.batchnorm2d(mid_channel), torch.nn.convtranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1) ) return block def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3): block = torch.nn.sequential( torch.nn.conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel), torch.nn.relu(), torch.nn.batchnorm2d(mid_channel), torch.nn.conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel), torch.nn.relu(), torch.nn.batchnorm2d(mid_channel), torch.nn.conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1), torch.nn.relu(), torch.nn.batchnorm2d(out_channels), ) return block def __init__(self, in_channel, out_channel): super(unet, self).__init__() #encode self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64) self.conv_maxpool1 = torch.nn.maxpool2d(kernel_size=2) self.conv_encode2 = self.contracting_block(64, 128) self.conv_maxpool2 = torch.nn.maxpool2d(kernel_size=2) self.conv_encode3 = self.contracting_block(128, 256) self.conv_maxpool3 = torch.nn.maxpool2d(kernel_size=2) # bottleneck self.bottleneck = torch.nn.sequential( torch.nn.conv2d(kernel_size=3, in_channels=256, out_channels=512), torch.nn.relu(), torch.nn.batchnorm2d(512), torch.nn.conv2d(kernel_size=3, in_channels=512, out_channels=512), torch.nn.relu(), torch.nn.batchnorm2d(512), torch.nn.convtranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1) ) # decode self.conv_decode3 = self.expansive_block(512, 256, 128) self.conv_decode2 = self.expansive_block(256, 128, 64) self.final_layer = self.final_block(128, 64, out_channel) def crop_and_concat(self, upsampled, bypass, crop=false): if crop: c = (bypass.size()[2] - upsampled.size()[2]) // 2 bypass = f.pad(bypass, (-c, -c, -c, -c)) return torch.cat((upsampled, bypass), 1) def forward(self, x): # encode encode_block1 = self.conv_encode1(x) encode_pool1 = self.conv_maxpool1(encode_block1) encode_block2 = self.conv_encode2(encode_pool1) encode_pool2 = self.conv_maxpool2(encode_block2) encode_block3 = self.conv_encode3(encode_pool2) encode_pool3 = self.conv_maxpool3(encode_block3) # bottleneck bottleneck1 = self.bottleneck(encode_pool3) # decode decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=true) cat_layer2 = self.conv_decode3(decode_block3) decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=true) cat_layer1 = self.conv_decode2(decode_block2) decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=true) final_layer = self.final_layer(decode_block1) return final_layer
以上python代码中的unet模块代表了unet的整体架构。使用contracaction_block和expansive_block分别创建contraction部分和expansion部分。crop_and_concat函数的作用是将contraction层的输出添加到新的expansion层输入中。训练部分的python代码可以写成
unet = unet(in_channel=1,out_channel=2)#out_channel represents number of segments desiredcriterion = torch.nn.crossentropyloss()optimizer = torch.optim.sgd(unet.parameters(), lr = 0.01, momentum=0.99)optimizer.zero_grad() outputs = unet(inputs)# permute such that number of desired segments would be on 4th dimensionoutputs = outputs.permute(0, 2, 3, 1)m = outputs.shape[0]# resizing the outputs and label to caculate pixel wise softmax lossoutputs = outputs.resize(m*width_out*height_out, 2)labels = labels.resize(m*width_out*height_out)loss = criterion(outputs, labels)loss.backward()optimizer.step()
结论
图像分割是一个重要的问题,每天都有一些新的研究论文发表。unet在这类研究中做出了重大贡献。许多新架构的灵感都来自unet。在业界,这种体系结构有很多变体,因此有必要理解第一个变体,以便更好地理解它们。
本文仅代表作者个人观点,不代表seo研究协会网官方发声,对观点有疑义请先联系作者本人进行修改,若内容非法请联系平台管理员,邮箱cxb5918@163.com。更多相关资讯,请到seo研究协会网www.seoxiehui.cn学习互联网营销技术请到巨推学院www.jutuiedu.com。
总结
以上是凯发k8官方网为你收集整理的unet详解_unet解释及python实现的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: python读取lmdb文件_pytho
- 下一篇: