1.本发明涉及小样本学习、元学习领域,具体地说,涉及一种基于特征金字塔和特征融合的小样本图像分类方法及系统。
背景技术:2.深度神经网络模型往往需要大量有标注数据的训练样本,才能达到较优的训练效果。而在现实中,样本的标签往往需要耗费大量的人力物力,或者某些情况下能够使用的样本数据就很少,这时若直接将少量样本用于训练,就会产生过拟合的问题,小样本学习正是为了解决这类问题而产生的。
3.将小样本学习的基本模型定义为p=c(f(x|θ)|w),其中特征提取器可以表示为f,分类器可以表示为c,x表示待识别的输入图像,θ表示特征提取器f的参数,w表示分类器c的参数,p表示模型输出的预测结果。在小样本学习的过程中,由于样本数量少,直接训练会导致模型参数θ和w过拟合,在目标任务上精度下降。
4.通过把大量可用数据的相似先前任务的训练集定义为d
base
,把包含目标检测任务的小样本学习数据集定义为d
novel
。将模型在d
base
上进行训练,学习到一个较优的参数θ和w。模型通过已经得到的初始化参数在d
novel
上进行训练,得到新的模型参数θ1和w1,更新原有参数,更新后的新模型p=c(f(x|θ1)|w1)能够较为准确的完成图像分类任务。
5.围绕样本数量少这一核心问题,现有的小样本学习策略中,主要通过基于数据增强的方法、基于度量学习的方法、基于模型的方法和基于参数优化的方法解决。度量学习也叫作相似度学习,度量学习的任务目标是学习一个成对的相似性度量s(
·
,
·
),类似的样本相似性分值更高,不类似的样本相似性分值较低。其中s既可以是一个无需学习的距离度量,也可以是一个可以进行学习的神经网络,其输出的相似性分值可以查询测试集样本分类。然而,现有的基于度量学习的小样本学习,只关注模型最后输出之间的相似性和距离度量,未关注模型网络中间层的相似性和度量,识别准确率较低,影响模型最终的分类准确性。
技术实现要素:6.本发明的目的在于克服现有技术小样本图像分类准确率较低的问题,提供一种基于特征金字塔和特征融合的小样本图像分类方法,通过新的特征金字塔关系网络和新的特征融合方式,提高在小样本图像分类中的准确性。
7.为了实现上述目的,本发明所采用的技术方案如下:
8.一种基于特征金字塔和特征融合的小样本图像分类方法,包括以下步骤:
9.s1.构建多层神经网络的特征金字塔关系网络模型,每层神经网络包括特征提取模块、关系模块和特征融合模块;
10.s2.获取数据集并对数据集进行扩充,将经过扩充后的数据集划分为训练集、验证集、测试集;
11.s3.采用c-way k-shot方式训练特征金字塔关系网络模型,每次训练分别从训练集采样支撑集与查询集;
12.s4.输入支撑集图像和查询集图像,特征提取模块提取图像的特征,并输出图像的特征向量,特征融合模块融合支撑集图像和查询集图像的特征向量;
13.s5.将融合后的特征向量输入关系模块,关系模块输出支撑集图像和查询集图像的相似分值,将所有关系模块输出的相似分值进行处理,获得最终的相似分值;
14.s6.计算特征金字塔关系网络模型的损失,并更新特征金字塔关系网络模型的参数,重复迭代训练,直至损失的误差值趋于稳定;
15.s7.保存训练好的特征金字塔关系网络模型,将特征金字塔关系网络模型用于小样本图像分类测试。
16.进一步,特征融合模块包括特征融合项,特征融合项为:
17.c
′
(fs,fq)=concate(fs,fq,mul(fs,fq))
18.式中,fs表示查询集图像的特征向量,fq表示支撑集图像的特征向量,concate(
·
,
·
)表示在特征通道进行拼接的操作,mul(
·
,
·
)运算表示将特征图按位置对应元素相乘。
19.进一步,步骤s6中,将一组图像的相似性分值视为回归任务,采用均方误差mse函数作为每层神经网络的损失函数,均方误差mse函数为:
20.mse(r,ys,yq)=(r-1(ys==yq))221.式中,r表示每层神经网络输出的相似性分数,ys表示支撑集图像的标签,yq表示查询集图像的标签。
22.进一步,步骤s6中,利用损失函数计算特征金字塔关系网络模型的损失,损失函数为:
[0023][0024]
式中,r
l
表示第l层神经网络输出的相似性分值,ys表示支撑集图像的标签,yq表示查询集图像的标签,mse表示均方误差函数,n表示神经网络的层数。
[0025]
进一步,步骤s2中,将获取的数据集通过旋转来进行扩充,旋转角度为90度、180度、270度。
[0026]
进一步,特征金字塔关系网络模型中,关系模块最后的全连接层的激活函数使用sigmoid函数,其他所有的激活函数都使用relu函数。
[0027]
一种基于特征金字塔和特征融合的小样本图像分类系统,包括:
[0028]
特征提取模块,用于提取输入图像的特征;
[0029]
特征融合模块,用于将输入图像的特征进行融合;
[0030]
关系模块,用于判别输入的支撑集图像特征和查询集图像特征的相似度。
[0031]
进一步,特征提取模块包括四个卷积块和两个2*2的最大池化层,卷积块、最大池化层、卷积块、最大池化层、卷积块、卷积块依次连接。
[0032]
进一步,关系模块包括两个卷积块、两个2*2的最大池化层、relu全连接层和sigmoid全连接层,卷积块、最大池化层、卷积块、最大池化层、relu全连接层、sigmoid全连接层依次连接。
[0033]
进一步,卷积块包括卷积层、batch norm层和relu激活函数层,卷积层的卷积核大小为3*3,输出通道数为64。
[0034]
与现有技术相比,本发明通过构建特征金字塔关系网络(fprn)模型,提高了小样本图像分类的精度,且由于特征金字塔关系网络(fprn)模型本身体量较小,通过特征金字塔关系网络(fprn)模型仍可快速获得检测结果,准确率高。
附图说明
[0035]
图1为特征金字塔关系网络fprn的结构示意图。
[0036]
图2为特征提取模块的结构示意图。
[0037]
图3为关系模块的结构示意图。
[0038]
图4为卷积块conv block的结构示意图。
具体实施方式
[0039]
下面结合附图和具体实施例对本发明基于特征金字塔和特征融合的小样本图像分类方法及系统作进一步说明。
[0040]
本发明公开了一种基于特征金字塔和特征融合的小样本图像分类方法,基于特征金字塔和特征融合的小样本图像分类方法包括以下步骤;
[0041]
s1.构建多层神经网络的特征金字塔关系网络模型,每层神经网络包括特征提取模块、关系模块和特征融合模块。
[0042]
s2.获取数据集并对数据集进行扩充,将经过扩充后的数据集划分为训练集、验证集、测试集。
[0043]
s3.采用c-way k-shot方式训练特征金字塔关系网络模型,每次训练分别从训练集采样支撑集与查询集。
[0044]
s4.输入支撑集图像和查询集图像,特征提取模块提取图像的特征,并输出图像的特征向量,特征融合模块融合支撑集图像和查询集图像的特征向量。
[0045]
s5.将融合后的特征向量输入关系模块,关系模块输出支撑集图像和查询集图像的相似分值,将所有关系模块输出的相似分值进行处理,获得最终的相似分值。
[0046]
s6.计算特征金字塔关系网络模型的损失,并更新特征金字塔关系网络模型的参数,重复迭代训练,直至损失的误差值趋于稳定。
[0047]
s7.保存训练好的特征金字塔关系网络模型,将特征金字塔关系网络模型用于小样本图像分类测试。
[0048]
请参阅图1,本发明还公开了一种基于特征金字塔和特征融合的小样本图像分类系统,基于特征金字塔和特征融合的小样本图像分类系统包括特征提取模块、特征融合模块和关系模块,特征提取模块用于提取输入图像的特征,特征融合模块用于将输入图像的特征进行融合,关系模块用于判别输入的支撑集图像特征和查询集图像特征的相似度。
[0049]
具体地,在神经网络中,神经网络层级越深,感受野越大,越关注图像的整体特征,神经网络的层级越浅,感受野越小,越关注图像的局部特征。例如,对动物进行分类,深层神经网络可以判别具体动物种类特征,浅层神经网络可以提取到毛发特征、背景纹理特征等,因此,浅层神经网络的特征也可以用来帮助判别动物种类。基于此,本发明提出了一种特征
[0065]
式中,r表示每层神经网络输出的相似性分数,ys表示支撑集图像的标签,yq表示查询集图像的标签。当标签相同时,(ys==yq)的值为1,当标签不同时,(ys==yq)的值为0。
[0066]
在特征金字塔关系网络(fprn)模型中,每层神经网络的关系模块都会输出一个相似性分值,因此特征金字塔关系网络(fprn)模型总的损失函数为:
[0067][0068]
式中,r
l
表示第l层输出的相似性分值,ys表示支撑集图像的标签,yq表示查询集图像的标签,mse表示均方误差函数。
[0069]
通过损失函数计算特征金字塔关系网络(fprn)模型的损失,并反向传播更新特征金字塔关系网络(fprn)模型的参数。重复迭代训练,直至损失函数计算得到的损失的误差值趋于稳定。
[0070]
保存训练好的特征金字塔关系网络模型,将特征金字塔关系网络模型用于小样本图像分类测试。本发明提出的基于特征金字塔和特征融合的小样本图像分类方法,在两个公开数据集上取得了较好的检测效果。
[0071]
omniglot数据集包含50种不同语言的1623种字符类,每个字符类包含20个由不同的人书写的样本。在训练过程中,20-way 1-shot每次训练,由每类别1张支撑集图像和10张查询集图像组成,20-way 5-shot每次训练,由每类别5张支撑集图像和5张查询集图像组成。在测试过程中,本发明在测试集中随机采样1000次来评估特征金字塔关系网络(fprn)模型的分类结果,其中1-shot每次采样1张测试集图像,5-shot每次采样5张测试集图像。
[0072]
miniimagenet数据集包含100个类别的共计60000张彩色图像组成,每个类别均包含600张样本,本发明取64个类别用以训练、16个类别用以验证、20个类别用以测试。在miniimagenet数据集上,本发明采用了5-way 1-shot和5-way 5-shot的设置。在训练过程中,5-way 1-shot的每次训练,由每类别1张支撑集图像和15张查询集图像组成,5-way 5-shot的每次训练,由每类别5张支撑集图像和10张查询集图像组成。在测试过程中,本发明在测试中随机采样600次来评估特征金字塔关系网络(fprn)模型的分类结果,其中5-way 1-shot和5-way 5-shot的设置中,每次均采样15张测试集图像。
[0073]
本发明比较了特征金字塔关系网络(fprn)模型与其他流行的基于度量学习的小样本学习模型的图像分类方法的结果。用于进行对比的模型主要包括孪生网络(siamese network)、原型网络(prototype network)、匹配网络(matching networks)和关系网络(relation network)。特征金字塔关系网络模型(fprn)与这些模型基准在omniglo数据集上的比较结果如表1所示。
[0074]
表1 omniglot数据集实验结果
[0075][0076]
特征金字塔关系网络(fprn)模型与孪生网络(siamese network)、原型网络(prototype network)、匹配网络(matching networks)和关系网络(relation network)这些模型基准在miniimagenet数据集上的比较结果如表2所示。
[0077]
表2 miniimagenet数据集实验结果
[0078][0079]
如表1和表2所示,通过实验数据表明,本发明提出的特征金字塔关系网络(fprn)模型,在各项实验中都取得了最高的判断准确率。在ominiglot数据集上,本发明提出的特征金字塔关系网络(fprn)模型,在20-way 1-shot的设置中,能够达到98.3%的分类准确率,在20-way 5-shot的设置中,能够达到99.2%的分类准确率。在miniimagenet数据集上,本发明提出的特征金字塔关系网络(fprn)模型,在5-way 1-shot的设置中,能够达到50.2%的分类准确率,在5-way 5-shot的设置中,能够达到66.7%的分类准确率。
[0080]
本发明在miniimagenet数据集上5-way 1-shot设置中,比较了关系网络模型和特征金字塔关系网络(fprn)模型的检测速度。本发明实验所用显卡为nvidia quadro p2000显卡,使用关系网络判别速度为17.1fps,使用特征金字塔关系网络模型(fprn)的判别速度为16.3fps,特征金字塔关系网络(fprn)模型比关系网络模型判别速度慢4.7%。在该实验设置中,特征金字塔关系网络(fprn)模型的检测准确率为50.2%,关系网络模型的检测准确率为47.3%,特征金字塔关系网络(fprn)模型比关系网络模型的检测准确率绝对值提升2.9%,按照百分比计算准确率提升为6.1%。特征金字塔关系网络(fprn)模型在牺牲4.7%的检测时间条件下,获得了6.1%的百分比准确率提升。
[0081]
综上所述,本发明通过构建特征金字塔关系网络(fprn)模型,提高了小样本图像
分类的精度,且由于特征金字塔关系网络(fprn)模型本身体量较小,通过特征金字塔关系网络(fprn)模型仍可快速获得检测结果,准确率高。
[0082]
上述说明是针对本发明较佳可行实施例的详细说明,但实施例并非用以限定本发明的专利申请范围,凡本发明所揭示的技术精神下所完成的同等变化或修饰变更,均应属于本发明所涵盖专利范围。
技术特征:1.一种基于特征金字塔和特征融合的小样本图像分类方法,其特征在于,包括以下步骤:s1.构建多层神经网络的特征金字塔关系网络模型,每层神经网络包括特征提取模块、关系模块和特征融合模块;s2.获取数据集并对数据集进行扩充,将经过扩充后的数据集划分为训练集、验证集、测试集;s3.采用c-way k-shot方式训练特征金字塔关系网络模型,每次训练分别从训练集采样支撑集与查询集;s4.输入支撑集图像和查询集图像,特征提取模块提取图像的特征,并输出图像的特征向量,特征融合模块融合支撑集图像和查询集图像的特征向量;s5.将融合后的特征向量输入关系模块,关系模块输出支撑集图像和查询集图像的相似分值,将所有关系模块输出的相似分值进行处理,获得最终的相似分值;s6.计算特征金字塔关系网络模型的损失,并更新特征金字塔关系网络模型的参数,重复迭代训练,直至损失的误差值趋于稳定;s7.保存训练好的特征金字塔关系网络模型,将特征金字塔关系网络模型用于小样本图像分类测试。2.如权利要求1的基于特征金字塔和特征融合的小样本图像分类方法,其特征在于,特征融合模块包括特征融合项,特征融合项为:c
′
(f
s
,f
q
)=concate(f
s
,f
q
,mul(f
s
,f
q
))式中,f
s
表示查询集图像的特征向量,f
q
表示支撑集图像的特征向量,concate(
·
,
·
)表示在特征通道进行拼接的操作,mul(
·
,
·
)运算表示将特征图按位置对应元素相乘。3.如权利要求1的基于特征金字塔和特征融合的小样本图像分类方法,其特征在于,步骤s6中,将一组图像的相似性分值视为回归任务,采用均方误差mse函数作为每层神经网络的损失函数,均方误差mse函数为:mse(r,y
s
,y
q
)=(r-1(y
s
==y
q
))2式中,r表示每层神经网络输出的相似性分数,y
s
表示支撑集图像的标签,y
q
表示查询集图像的标签。4.如权利要求3的基于特征金字塔和特征融合的小样本图像分类方法,其特征在于,步骤s6中,利用损失函数计算特征金字塔关系网络模型的损失,损失函数为:式中,r
l
表示第l层神经网络输出的相似性分值,y
s
表示支撑集图像的标签,y
q
表示查询集图像的标签,mse表示均方误差函数,n表示神经网络的层数。5.如权利要求1的基于特征金字塔和特征融合的小样本图像分类方法,其特征在于,步骤s2中,将获取的数据集通过旋转来进行扩充,旋转角度为90度、180度、270度。6.如权利要求1的基于特征金字塔和特征融合的小样本图像分类方法,其特征在于,特征金字塔关系网络模型中,关系模块最后的全连接层的激活函数使用sigmoid函数,其他所有的激活函数都使用relu函数。
7.一种基于特征金字塔和特征融合的小样本图像分类系统,其特征在于,包括:特征提取模块,用于提取输入图像的特征;特征融合模块,用于将输入图像的特征进行融合;关系模块,用于判别输入的支撑集图像特征和查询集图像特征的相似度。8.如权利要求7的基于特征金字塔和特征融合的小样本图像分类系统,其特征在于,特征提取模块包括四个卷积块和两个2*2的最大池化层,卷积块、最大池化层、卷积块、最大池化层、卷积块、卷积块依次连接。9.如权利要求7的基于特征金字塔和特征融合的小样本图像分类系统,其特征在于,关系模块包括两个卷积块、两个2*2的最大池化层、relu全连接层和sigmoid全连接层,卷积块、最大池化层、卷积块、最大池化层、relu全连接层、sigmoid全连接层依次连接。10.如权利要求8或9的基于特征金字塔和特征融合的小样本图像分类系统,其特征在于,卷积块包括卷积层、batch norm层和relu激活函数层,卷积层的卷积核大小为3*3,输出通道数为64。
技术总结本发明公开了一种基于特征金字塔和特征融合的小样本图像分类方法,包括以下步骤:S1.构建特征金字塔关系网络模型,模型包括特征提取模块、关系模块和特征融合模块;S2.对数据集进行扩充,将数据集划分为训练集、验证集、测试集;S3.训练模型,从训练集中采样支撑集与查询集;S4.输入支撑集图像和查询集图像,特征提取模块提取图像的特征,并输出图像的特征向量,特征融合模块融合特征向量;S5.将特征向量输入关系模块,关系模块输出支撑集图像和查询集图像的相似分值,将所有相似分值进行处理,获得最终的相似分值;S6.计算模型的损失,并更新模型的参数,重复迭代训练,直至损失的误差值趋于稳定;S7.保存训练好的模型,将模型用于小样本图像分类测试。样本图像分类测试。样本图像分类测试。
技术研发人员:王先知 许洁斌 艾浩然
受保护的技术使用者:华南理工大学
技术研发日:2022.06.27
技术公布日:2022/11/1