基于数据增强的逐像素Q值估计离线强化学习方法和装置

专利2023-02-25  121


基于数据增强的逐像素q值估计离线强化学习方法和装置
技术领域
1.本技术涉及强化学习技术领域,特别是涉及一种基于数据增强的逐像素q 值估计离线强化学习方法和装置。


背景技术:

2.强化学习(rl)最近在自动驾驶、视频游戏和辅助医疗等诸多领域都取得重大进展。rl通过与环境的实时互动,在不断试错中收集最新的数据,迭代学习最优策略。然而,在线rl在真实环境中的应用却面临挑战:首先,智能体与真实环境进行试错交互可能会造成毁灭性后果,比如自动驾驶的车辆在道路上的探索行驶将对行人安全造成严重威胁;其次,通过与环境互动来积累和收集数据成本高昂且易产生资源浪费,制约了智能体频繁和无限制地收集数据;最后,智能体在与环境的交互中在线收集的数据的质量往往参差不齐,特别是大量无效探索产生的数据价值并不高,导致智能体学习过程中的数据利用率低。
3.离线rl提供了一种从现有数据集中学习,而无需与环境进一步互动的方法,使算法的训练避免高风险的试错,并极大节约计算资源。之前的离线rl算法,如行为克隆(bc),批量约束的深度q-learning(bcq)和保守q-learning (cql)等方法从先前收集的大型数据集中进行学习,且在训练期间不再更新数据集。这些算法有助于学习性能高效且数据利用率高的策略。尽管离线rl优势显著,但由于两个主要原因,在训练离线rl时,智能体很容易陷入次优策略: (a)数据集的质量和数量通常有限;(b)因分布外(ood)数据引发的外推误差增大。因此,现有技术存在对离线数据质量要求高、学习过程中数据利用率低的问题。


技术实现要素:

4.基于此,有必要针对上述技术问题,提供一种能够提升离线强化学习方法性能的基于数据增强的逐像素q值估计离线强化学习方法、装置、计算机设备和存储介质。
5.一种基于数据增强的逐像素q值估计离线强化学习方法,所述方法包括:
6.根据预先获取的智能体视觉控制的离线数据集,从所述离线数据集中采样小批量的原始输入观测;
7.通过常见的图像变换算法对所述小批量的原始输入观测进行数据增强;每一小批量的数据由batch个数据组(s,a,r,s

)组成,其中batch为小批量数据的数量,s代表当前时刻图像堆栈,s

代表下一时刻的图像堆栈,a表示当前时刻智能体采取的动作,r表示智能体在当前时刻s采取动作a后得到的环境奖励反馈;每个所述图像堆栈中包含若干个连续帧堆叠的原始观测图像;图像堆栈s和s

分别进行k次和m次数据增强后,分别得到第一扩充样本和第二扩充样本;所述第一扩充样本中包括k个输入观测,所述第二扩充样本中包括m个输入观测;每一图像堆栈内的图像变换算法的参数设置一致,各图像堆栈的图像变换算法参数随机设置;
8.通过预设的q值估计网络得到所述第一扩充样本中k个输入观测的q值,将k个输入观测的平均q值作为对应原始输入观测的预测q值;所述q值估计网络为任意使用q值估计的
离线强化学习网络;
9.通过预设的q值目标网络得到所述第二扩充样本中m个输入观测的q值,根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q值;所述q值目标网络与所述q值估计网络使用相同网络架构;
10.根据所述预测q值和所述目标q值,通过最小化mse均方误差损失函数对所述q值估计网络的参数进行更新;
11.根据更新后的所述q值估计网络的参数对所述q值目标网络的参数进行软更新,直到达到预设时间步停止更新;
12.以完成更新的q值估计网络为最终学习到的q值估计强化学习网络。
13.在其中一个实施例中,还包括:获取训练好的学习网络作为专家策略或行为策略;
14.根据所述专家策略或行为策略,确定输入观测对应的具有最高值的动作及相应奖励;
15.由当前时刻图像堆栈s、对应的动作a,对应的奖励r及下一时刻的图像堆栈s

,构成一个离线数据元组(s,a,r,s

);
16.多次采样,得到大量离线数据元组,构成智能体视觉控制的离线数据集。
17.在其中一个实施例中,还包括:将随机的图像变换算法应用于所述小批量的原始输入观测得到变换图像;
18.在所述变换图像上使用双线性插值算法,给每边填充2个像素,得到填充图像;
19.在所述填充图像上进行随机裁剪,得到与原图像相同规格的数据增强后的图像。
20.在其中一个实施例中,还包括:通过预设的q值估计网络得到所述第一扩充样本中k个输入观测的q值;
21.计算k个输入观测的平均q值为:
[0022][0023]
其中,i为所述原始输入观测的索引,θ为所述q值估计网络的参数,(si,ai) 为所述原始输入观测的状态-动作对,f(s,v)为通过调整v对所述原始输入观测进行变换的变换函数,其中,f保留了状态-动作对的q值不变,即 q
θ
(s,a)=q
θ
(f(s,ν),a);
[0024]
将所述平均q值q
θ
(si,ai)作为对应原始输入观测的预测q值。
[0025]
在其中一个实施例中,还包括:通过预设的q值目标网络得到所述第二扩充样本中m个输入观测的q值;
[0026]
根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q 值为:
[0027][0028]
其中,θ

为q值目标网络的参数,ri为所述原始输入观测的奖励值,γ∈(0,1] 是折扣系数,λ为调节系数,为通过调整对所述原始输入观测进行变换的变换函数。
[0029]
在其中一个实施例中,还包括:根据所述预测q值和所述目标q值,以所述q值目标
网络为目标网络,通过最小化mse均方误差损失函数对所述q值估计网络的参数进行更新:
[0030][0031]
其中,n为小批量的批量大小,α是学习率,表示q值估计的期望。
[0032]
在其中一个实施例中,还包括:根据更新后的所述q值估计网络的参数对所述q值目标网络的参数进行软更新:
[0033]
θ'

τθ+(1-τ)θ'
[0034]
其中,τ为更新系数。
[0035]
一种基于数据增强的逐像素q值估计离线强化学习装置,所述装置包括:
[0036]
原始输入观测获取模块,用于根据预先获取的智能体视觉控制的离线数据集,从所述离线数据集中采样小批量的原始输入观测;
[0037]
数据增强模块,用于通过常见的图像变换算法对所述小批量的原始输入观测进行数据增强;每一小批量的数据由batch个数据组(s,a,r,s

)组成,其中batch 为小批量数据的数量,s代表当前时刻图像堆栈,s

代表下一时刻的图像堆栈, a表示当前时刻智能体采取的动作,r表示智能体在当前时刻s采取动作a后得到的环境奖励反馈;每个所述图像堆栈中包含若干个连续帧堆叠的原始观测图像;图像堆栈s和s

分别进行k次和m次数据增强后,分别得到第一扩充样本和第二扩充样本;所述第一扩充样本中包括k个输入观测,所述第二扩充样本中包括m个输入观测;每一图像堆栈内的图像变换算法的参数设置一致,各图像堆栈的图像变换算法参数随机设置;
[0038]
预测q值计算模块,用于通过预设的q值估计网络得到所述第一扩充样本中k个输入观测的q值,将k个输入观测的平均q值作为对应原始输入观测的预测q值;所述q值估计网络为任意使用q值估计的离线强化学习网络;
[0039]
目标q值计算模块,用于通过预设的q值目标网络得到所述第二扩充样本中m个输入观测的q值,根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q值;所述q值目标网络与所述q值估计网络使用相同网络架构;
[0040]
q值估计网络参数更新模块,用于根据所述预测q值和所述目标q值,通过最小化mse均方误差损失函数对所述q值估计网络的参数进行更新;
[0041]
q值目标网络参数更新模块,用于根据更新后的所述q值估计网络的参数对所述q值目标网络的参数进行软更新,直到达到预设时间步停止更新;以完成更新的q值估计网络为最终学习到的q值估计强化学习网络。
[0042]
一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现以下步骤:
[0043]
根据预先获取的智能体视觉控制的离线数据集,从所述离线数据集中采样小批量的原始输入观测;
[0044]
通过常见的图像变换算法对所述小批量的原始输入观测进行数据增强;每一小批量的数据由batch个数据组(s,a,r,s

)组成,其中batch为小批量数据的数量,s代表当前时刻图像堆栈,s

代表下一时刻的图像堆栈,a表示当前时刻智能体采取的动作,r表示智能体在当前时刻s采取动作a后得到的环境奖励反馈;每个所述图像堆栈中包含若干个连续帧堆
叠的原始观测图像;图像堆栈s和s

分别进行k次和m次数据增强后,分别得到第一扩充样本和第二扩充样本;所述第一扩充样本中包括k个输入观测,所述第二扩充样本中包括m个输入观测;每一图像堆栈内的图像变换算法的参数设置一致,各图像堆栈的图像变换算法参数随机设置;
[0045]
通过预设的q值估计网络得到所述第一扩充样本中k个输入观测的q值,将k个输入观测的平均q值作为对应原始输入观测的预测q值;所述q值估计网络为任意使用q值估计的离线强化学习网络;
[0046]
通过预设的q值目标网络得到所述第二扩充样本中m个输入观测的q值,根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q值;所述q值目标网络与所述q值估计网络使用相同网络架构;
[0047]
根据所述预测q值和所述目标q值,通过最小化mse均方误差损失函数对所述q值估计网络的参数进行更新;
[0048]
根据更新后的所述q值估计网络的参数对所述q值目标网络的参数进行软更新,直到达到预设时间步停止更新;
[0049]
以完成更新的q值估计网络为最终学习到的q值估计强化学习网络。
[0050]
一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现以下步骤:
[0051]
根据预先获取的智能体视觉控制的离线数据集,从所述离线数据集中采样小批量的原始输入观测;
[0052]
通过常见的图像变换算法对所述小批量的原始输入观测进行数据增强;每一小批量的数据由batch个数据组(s,a,r,s

)组成,其中batch为小批量数据的数量,s代表当前时刻图像堆栈,s

代表下一时刻的图像堆栈,a表示当前时刻智能体采取的动作,r表示智能体在当前时刻s采取动作a后得到的环境奖励反馈;每个所述图像堆栈中包含若干个连续帧堆叠的原始观测图像;图像堆栈s和s

分别进行k次和m次数据增强后,分别得到第一扩充样本和第二扩充样本;所述第一扩充样本中包括k个输入观测,所述第二扩充样本中包括m个输入观测;每一图像堆栈内的图像变换算法的参数设置一致,各图像堆栈的图像变换算法参数随机设置;
[0053]
通过预设的q值估计网络得到所述第一扩充样本中k个输入观测的q值,将k个输入观测的平均q值作为对应原始输入观测的预测q值;所述q值估计网络为任意使用q值估计的离线强化学习网络;
[0054]
通过预设的q值目标网络得到所述第二扩充样本中m个输入观测的q值,根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q值;所述q值目标网络与所述q值估计网络使用相同网络架构;
[0055]
根据所述预测q值和所述目标q值,通过最小化mse均方误差损失函数对所述q值估计网络的参数进行更新;
[0056]
根据更新后的所述q值估计网络的参数对所述q值目标网络的参数进行软更新,直到达到预设时间步停止更新;
[0057]
以完成更新的q值估计网络为最终学习到的q值估计强化学习网络。
[0058]
上述基于数据增强的逐像素q值估计离线强化学习方法、装置、计算机设备和存储
介质,通过从所述离线数据集中采样小批量的原始输入观测,通过常见的图像变换算法对所述小批量的原始输入观测进行数据增强,得到第一扩充样本和第二扩充样本,通过q值估计网络得到所述第一扩充样本中k个输入观测的q值,将k个输入观测的平均q值作为对应原始输入观测的q值,通过q 值目标网络得到m个输入观测的q值,根据m个输入观测的q值基于td-error 计算对应原始输入观测的q值目标,通过最小化损失函数对q值估计网络的参数进行更新,再根据更新后的q值估计网络的参数对所述q值目标网络的参数进行软更新,最终训练得到q值估计网络。本发明通过上述过程使用标准增量方法扩大训练数据集,通过正则化输入观测的q值函数,使得在数据集中的数据附近,但又不在数据集内的数据,具有与数据集中的观测值一致的q值,避免高估离线数据集中的静态数据,并显著提升了算法的泛化性;本发明将数据增强与基于像素观测的离线rl结合起来的方法,不需要对底层rl算法进行额外修改,使得该方法易于实现,并可扩展应用到其他算法,可扩展性强,实用性佳。
附图说明
[0059]
图1为一个实施例中基于数据增强的逐像素q值估计离线强化学习方法的流程示意图;
[0060]
图2为一个实施例中基于数据增强的逐像素q值估计离线强化学习方法的管道图;
[0061]
图3为一个实施例中基于数据增强的逐像素q值估计离线强化学习中使用的增强变换方法;
[0062]
图4为一个实施例中使用daq和bcq算法的特征抽取网络层处理后的 breakout游戏样本的t-sne结果,其中(a)为daq算法的breakout游戏样本经特征抽取网络层处理后的t-sne结果,(b)为breakout游戏样本经bcq算法特征抽取网络层处理后的t-sne结果;
[0063]
图5为一个实施例中基于数据增强的逐像素q值估计离线强化学习装置的结构框图;
[0064]
图6为一个实施例中计算机设备的内部结构图。
具体实施方式
[0065]
为了使本技术的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本技术进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本技术,并不用于限定本技术。
[0066]
在一个实施例中,如图1所示,提供了一种基于数据增强的逐像素q值估计离线强化学习方法,包括以下步骤:
[0067]
步骤102,根据预先获取的智能体视觉控制的离线数据集,从离线数据集中采样小批量的原始输入观测。
[0068]
每个图像堆栈中包含若干个连续帧堆叠的原始观测图像。这里之所以需要使用连续的若干帧堆叠,是因为只用1帧是不足以捕捉到真实的状态信息的,按照处理视觉控制任务的惯例,使用连续的3或4帧堆叠的图像以最大可能的捕捉当前状态的真实信息。举例,比如单帧是难以获取游戏画面或机器人场景中的位移、速度等信息,而连续的帧作为输入则可以使智能体捕捉到这些信息。
[0069]
在离线强化学习(offlinerl)中,从行为策略π
β
中采样数据集作为离线数据集,
表示离线数据集中一个小批量数据。离线强化学习中,智能体不能与环境交互。
[0070]
本发明提出了基于数据增强的逐像素q值估计离线强化学习方法(daq),逐像素的意思是智能体欲解决的控制任务的输入是pixels(像素或图像),相对应的则是state(向量或数组)。本发明方法是一种结合数据增强和q值估计的离线强化学习方法框架,通过对视觉控制任务的图像输入进行逐像素数据增强,并正则化q值估计,以提升离线强化学习算法泛化性和数据利用率。这是首次在离散动作空间的视觉控制任务上使用数据增强的离线强化学习算法。该方法充分利用了图像数据的旋转、平移等不变性的性质,使得在离散动作空间的视觉控制任务上使用离线强化学习算法时,有效扩大了像素级的图像输入的数据集。在像素级的离线rl上使用数据增强方法,最主要的困难是需克服分布外(out ofdistribution)数据的过估计问题,为此,我们提出使用正则化q值的方法,有效缓解了该问题。
[0071]
如图2所示,daq应用标准变换来增加从预先收集的数据集中采样的输入观测,并正则化输入观测值的q值函数。
[0072]
本实施例中,从离线数据集中采样小批量的原始输入观测,其中,原始输入观测为若干个连续图像的堆栈。
[0073]
步骤104,通过常见的图像变换算法对小批量的原始输入观测进行数据增强。
[0074]
其中,对当前时刻和下一时刻的每一图像堆栈分别进行k次和m次数据增强,分别得到第一扩充样本和第二扩充样本;每一图像堆栈内的图像变换算法的参数设置一致,各图像堆栈的图像变换算法的参数随机设置;第一扩充样本中包括k个输入观测,第二扩充样本中包括m个输入观测。
[0075]
值得注意的是,这里的数据增强操作只针对数据元组中的原始输入观测(即 s和s

),而不改变数据元组中的动作a和奖励r。
[0076]
考虑变换函数f:s
×
γ

s,这里γ是控制变换的参数集合。特别地,假设存在e∈γ,并且对于任何f(s,e)=s成立。本发明假设转换函数f保留了状态-动作对的q值不变,即:
[0077]qθ
(s,a)=q
θ
(f(s,ν),a)
ꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀꢀ
(1)
[0078]
对任何ν∈γ都成立,θ是q值函数的参数。
[0079]
从常见的图像变换中选择随机变换来调整从离线数据集中采样的输入观测,如随机翻转、裁剪和旋转等方法。首先将随机变换图像增强算法应用于若干个连续帧堆叠的原始观测图像,在一个批次的数据中随机应用数据增强方法,但在每个数据堆上使用统一的增强方法,这可以保持堆栈图像内的时域信息(比如速度、位移、旋转角度等)。训练数据通过在变换图像上使用双线性插值方法,给每边填充2个像素后,随机裁剪图像,最后生成在原始图像基础上变换
±
4像素的图像。数据增强的方法如图3所示。
[0080]
使用传统的数据增强方法使图像观测多样化,以有效估计静态数据集中数据的值,而无需修改底层离线rl算法。
[0081]
步骤106,通过预设的q值估计网络得到第一扩充样本中k个输入观测的 q值,将k个输入观测的平均q值作为对应原始输入观测的预测q值。
[0082]
q值估计网络可以是任意使用q值估计的离线强化学习网络。
[0083]
具体地,对于采样的mini-batch中的任意状态-动作对(si,ai,ri,si′
),从γ中按照(增强数据包括原始状态)统一进行k次变换为了减少价值函数估计的方差,使用来估计q
θ
(si,ai)。即通过预设的q值估计网络得到第一扩充样本中k个输入观测的q值;
[0084]
计算k个输入观测的平均q值为:
[0085][0086]
其中,i为原始输入观测的索引,θ为q值估计网络的参数,(si,ai)为原始输入观测的状态-动作对,f(s,v)为通过变换v对原始输入观测进行变换的变换函数,其中,f保留了状态-动作对的q值不变,即q
θ
(s,a)=q
θ
(f(s,ν),a);
[0087]
将平均q值q
θ
(si,ai)作为对应原始输入观测的预测q值。
[0088]
在这里,由于假设增强状态处于原始状态的近似分布中,并且在相同的策略下它们应该具有相似的动作,因此这里不会对增强状态生成新的动作以节省计算。
[0089]
步骤108,通过预设的q值目标网络得到第二扩充样本中m个输入观测的 q值,根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标 q值。
[0090]
q值目标网络与q值估计网络使用相同网络架构。
[0091]
具体地,当离线强化学习的算法为bcq时,通过q值目标网络得到第二扩充样本中m个输入观测的q值;
[0092]
根据m个输入观测的q值基于td-error计算对应原始输入观测的目标q 值为:
[0093][0094]
其中,θ

为q值目标网络的参数,ri为原始输入观测的奖励值,γ∈(0,1]是折扣系数,λ为调节系数,为通过调整对原始输入观测进行变换的变换函数。
[0095]
步骤110,根据预测q值和目标q值,通过最小化mse均方误差损失函数对q值估计网络的参数进行更新。
[0096]
具体地,根据预测q值和目标q值,以q值目标网络为目标网络,通过最小化损失函数对q值估计网络的参数进行更新:
[0097][0098]
其中,n为小批量的批量大小,α是学习率,表示q值估计的期望。
[0099]
步骤112,根据更新后的q值估计网络的参数对q值目标网络的参数进行软更新,直到达到预设时间步停止更新,以完成更新的q值估计网络为最终学习到的q值估计强化学习网络。
[0100]
具体地,根据更新后的q值估计网络的参数对q值目标网络的参数进行软更新:
[0101]
θ'

τθ+(1-τ)θ'
[0102]
其中,τ为调节系数。
[0103]
上述基于数据增强的逐像素q值估计离线强化学习方法中,通过从离线数据集中
采样小批量的原始输入观测,通过常见的图像变换算法对小批量的原始输入观测进行数据增强,得到第一扩充样本和第二扩充样本,通过q值估计网络得到第一扩充样本中k个输入观测的q值,将k个输入观测的平均q值作为对应原始输入观测的q值估计,通过q值目标网络得到第二扩充样本中m个输入观测的q值,根据m个输入观测的q值基于td-error计算对应原始输入观测的q值目标,通过最小化损失函数对q值估计网络的参数进行更新,再根据更新后的q值估计网络的参数对q值目标网络的参数进行软更新,最终训练得到q值估计网络。本发明通过上述过程使用标准增强方法扩大训练数据集,通过正则化输入观测的q值函数,使得在数据集中的数据附近,但又不在数据集内的数据,具有与数据集中的观测一致的q值,避免高估离线数据集中的静态数据,并显著提升了算法泛化性;本发明将数据增强与基于像素观测的离线rl 结合起来的方法,不需要对底层rl进行额外修改,使得该方法易于实现,并可扩展应用到其他算法,可扩展性强,实用性佳。
[0104]
在其中一个实施例中,还包括:获取训练好的学习网络作为专家策略或行为策略;根据专家策略或行为策略,确定输入观测对应的具有最高值的动作及相应奖励;由当前时刻图像堆栈s、对应的动作a,对应的奖励r及下一时刻的图像堆栈s

,构成一个离线数据元组(s,a,r,s

);多次采样,得到大量离线数据元组,构成智能体视觉控制的离线数据集。
[0105]
在其中一个实施例中,还包括:将随机的图像变换算法应用于小批量的原始输入观测得到变换图像;在变换图像上使用双线性插值算法,给每边填充2 个像素,得到填充图像;在填充图像上进行随机裁剪,得到与原图像相同规格的数据增强后的图像。
[0106]
之所以使用双线性插值而不是对图像边缘的像素进行简单的复制,是因为前一种方法将每个像素值替换为最近的2个像素值的平均值,这样可以获得更平滑的信息。
[0107]
在另一个实施例中,提供daq算法的伪代码如下:
[0108][0109]
在一个具体实施例中,使用bcq作为主干离线rl算法,并使用随机转换的数据增强方法,构成本实施例的daq算法。在atari游戏上评估daq算法,通过daq算法解决图像中的离
散控制任务,并在性能、泛化性和样本效率方面与bcq算法进行比较。
[0110]
为获得观测的全部信息,真实应用的输入观测是沿通道维度的4个连续图像的堆栈,其中像素为84
×
84,由环境渲染而成。选择图像转换作为图像变换的方法。对于q值估计网络和q值目标网络,变换的次数分别为[k=4,m=4]。在训练期间,使用100万个时间步训练的深度q学习网络dqn作为行为策略来收集100万个(s,a,r,s

,d)元组作为离线数据集。本发明的daq方法在整个训练过程中使用离线数据集来训练智能体。
[0111]
为了生成静态重放缓冲区,需要与环境交互以通过行为策略收集经验。考虑到观察输入必须是原始图像,选择atari游戏作为环境,并选择训练深度q 学习网络(dqn)作为行为策略以适应离散动作空间。经过100万步的训练,得到最终收敛到性能稳定的价值网络。使用训练好的网络选择每个观测对应的具有最高值的动作来生成数据元组,分别保存动作、观测和完成(游戏中一个 episode的终止标志),下一个观测可以从观测块中通过一些简单的操作来实现。为增加数据的多样性,在开始的一定时间步长内从动作空间中随机采样以选择随机动作,其余的时间步长则通过ε-greedy选择动作,ε从1线性退火到最后一步的0.01。
[0112]
评估表现:
[0113]
与离线强化学习方法比较:为了与bcq算法进行比较,daq算法的所有超参数都与bcq保持相同。为了可重复性,不对原始环境或奖励函数进行任何修改,并且按照惯例,使用策略步骤而不是真实环境步骤(真实环境步骤是策略步骤的4倍,因为每个动作重复执行4次)来实现性能。两种算法都训练了100万步,一个片段(episode)对应1000步,每步奖励在[0,1]范围内。 bcq和cql是应用于基于图像的输入观测的离散动作空间的原始算法的规范版本。离线dqn是在离线数据集上训练的dqn智能体,无需与环境交互。缓冲性能(buffer performance)是行为策略生成的数据集的性能。每个评估是5 个片段上的性能的平均。结果表明,本发明的daq算法在8个任务上优于或匹配最先进的离线rl方法,并且在3个任务(demonattack、jamesbond和 seaquest)上获得了较大的收益。
[0114]
提升的泛化能力:为了验证daq方法的泛化能力,应用daq和bcq算法的训练模型分别处理原始数据和增强数据,以获得每个输入图像堆栈的特征。首先,从数据集中采样20个图像堆栈作为原始输入观察值;其次,通过随机变换对每个堆栈进行4次扩充,生成80个图像堆栈;最后,使用先前训练模型的图像处理层(包括bcq和daq模型中的处理层)来提取观察结果的内在特征。使用t-sne来呈现特征的集群,以比较两种方法的泛化能力。图4使用经daq和bcq算法的特征抽取层分别处理的breakout游戏样本的 t-sne结果。图中使用的数据是从批量转换的4个原始采样输入观测值扩展而来的。除了原来的4个数据外,其他数据都是通过转换增强生成的。该图显示,即使有很多未见过的图像,daq(图a)可以更好地对它们进行聚类,这意味着daq在泛化方面优于bcq(图b)。图表明daq算法比特征图中的bcq 更好地对增强图像堆栈进行聚类。这表明数据增强使网络能够学习更多概括性的表示。
[0115]
消融实验:
[0116]
数据有效性:与在线强化学习通过与环境交互步数来反映样本效率不同,离线强化学习通过在有限数据训练下比较性能来达到评价数据有效性的目的。为了评估daq的数据效率,选择使用10%和50%的数据集来训练模型,然后评估模型性能来比较数据效率。在表1中的实验表明,在许多情况下,在 50%数据集上进行训练时,daq可以与bcq的样本效率
相媲美。同时,可以发现在使用10%数据集时,bcq在大多数任务上都优于daq。
[0117]
表1使用10%和50%数据集的不同atari游戏的平均片段回报
[0118][0119]
应该理解的,虽然图1的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,图1中的至少一部分步骤可以包括多个子步骤或者多个阶段,这些子步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些子步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤的子步骤或者阶段的至少一部分轮流或者交替地执行。
[0120]
在一个实施例中,如图5所示,提供了一种基于数据增强的逐像素q值估计离线强化学习装置,包括:原始输入观测获取模块502、数据增强模块504、预测q值计算模块506、目标q值计算模块508、q值估计网络参数更新模块 510和q值目标网络参数更新模块512,其中:
[0121]
原始输入观测获取模块502,用于根据预先获取的智能体视觉控制的离线数据集,从离线数据集中采样小批量的原始输入观测;
[0122]
数据增强模块504,用于通过常见的图像变换算法对小批量的原始输入观测进行数据增强;每一小批量的数据由batch个数据组(s,a,r,s

)组成,其中batch 为小批量数据的数量,s代表当前时刻图像堆栈,s

代表下一时刻的图像堆栈, a表示当前时刻智能体采取的动作,r表示智能体在当前时刻s采取动作a后得到的环境奖励反馈;每个图像堆栈中包含若干个连续帧堆叠的原始观测图像;图像堆栈s和s

分别进行k次和m次数据增强后,分别得到第一扩充样本和第二扩充样本;第一扩充样本中包括k个输入观测,第二扩充样本中包括m个输入观测;每一图像堆栈内的图像变换算法的参数设置一致,各图像堆栈的图像变换算法参数随机设置;
[0123]
预测q值计算模块506,用于通过预设的q值估计网络得到第一扩充样本中k个输入观测的q值,将k个输入观测的平均q值作为对应原始输入观测的预测q值;q值估计网络为任意使用q值估计的离线强化学习网络;
[0124]
目标q值计算模块508,用于通过预设的q值目标网络得到第二扩充样本中m个输入观测的q值,根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q值;q值目标网络与q值估计网络使用相同网络架构;
[0125]
q值估计网络参数更新模块510,用于根据预测q值和目标q值,通过最小化mse均方误差损失函数对q值估计网络的参数进行更新;
[0126]
q值目标网络参数更新模块512,用于根据更新后的q值估计网络的参数对q值目标网络的参数进行软更新,直到达到预设时间步停止更新;以完成更新的q值估计网络为最终学习到的q值估计强化学习网络。
[0127]
原始输入观测获取模块502还用于获取训练好的学习网络作为专家策略或行为策略;根据专家策略或行为策略,确定输入观测对应的具有最高值的动作及相应奖励;由当前时刻图像堆栈s、对应的动作a,对应的奖励r及下一时刻的图像堆栈s

,构成一个离线数据元组(s,a,r,s

);多次采样,得到大量离线数据元组,构成智能体视觉控制的离线数据集。
[0128]
数据增强模块504还用于将随机的图像变换算法应用于小批量的原始输入观测得到变换图像;在变换图像上使用双线性插值算法,给每边填充2个像素,得到填充图像;在填充图像上进行随机裁剪,得到与原图像相同规格的数据增强后的图像。
[0129]
预测q值计算模块506还用于通过预设的q值估计网络得到第一扩充样本中k个输入观测的q值;计算k个输入观测的平均q值为:
[0130][0131]
其中,i为原始输入观测的索引,θ为q值估计网络的参数,(si,ai)为原始输入观测的状态-动作对,f(s,v)为通过调整v对原始输入观测进行变换的变换函数,其中,f保留了状态-动作对的q值不变,即q
θ
(s,a)=q
θ
(f(s,ν),a);将平均q 值q
θ
(si,ai)作为对应原始输入观测的预测q值。
[0132]
目标q值计算模块508还用于通过预设的q值目标网络得到第二扩充样本中m个输入观测的q值;根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q值为:
[0133][0134]
其中,θ

为q值目标网络的参数,ri为原始输入观测的奖励值,γ∈(0,1]是折扣系数,λ为调节系数,为通过调整对原始输入观测进行变换的变换函数。
[0135]
q值估计网络参数更新模块510还用于根据预测q值和目标q值,以q值目标网络为目标网络,通过最小化mse均方误差损失函数对q值估计网络的参数进行更新:
[0136][0137]
其中,n为小批量的批量大小,α是学习率,表示q值估计的期望。
[0138]
q值目标网络参数更新模块512还用于根据更新后的q值估计网络的参数对q值目标网络的参数进行软更新:
[0139]
θ'

τθ+(1-τ)θ'
[0140]
其中,τ为更新系数。
[0141]
关于基于数据增强的逐像素q值估计离线强化学习装置的具体限定可以参见上文中对于基于数据增强的逐像素q值估计离线强化学习方法的限定,在此不再赘述。上述基于数据增强的逐像素q值估计离线强化学习装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
[0142]
在一个实施例中,提供了一种计算机设备,该计算机设备可以是终端,其内部结构图可以如图6所示。该计算机设备包括通过系统总线连接的处理器、存储器、网络接口、显示屏和输入装置。其中,该计算机设备的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统和计算机程序。该内存储器为非易失性存储介质中的操作系统和计算机程序的运行提供环境。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现一种基于数据增强的逐像素q值估计离线强化学习方法。该计算机设备的显示屏可以是液晶显示屏或者电子墨水显示屏,该计算机设备的输入装置可以是显示屏上覆盖的触摸层,也可以是计算机设备外壳上设置的按键、轨迹球或触控板,还可以是外接的键盘、触控板或鼠标等。
[0143]
本领域技术人员可以理解,图6中示出的结构,仅仅是与本技术方案相关的部分结构的框图,并不构成对本技术方案所应用于其上的计算机设备的限定,具体的计算机设备可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
[0144]
在一个实施例中,提供了一种计算机设备,包括存储器和处理器,该存储器存储有计算机程序,该处理器执行计算机程序时实现上述方法实施例中的步骤。
[0145]
在一个实施例中,提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现上述方法实施例中的步骤。
[0146]
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本技术所提供的各实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可包括只读存储器(rom)、可编程rom(prom)、电可编程 rom(eprom)、电可擦除可编程rom(eeprom)或闪存。易失性存储器可包括随机存取存储器(ram)或者外部高速缓冲存储器。作为说明而非局限, ram以多种形式可得,诸如静态ram(sram)、动态ram(dram)、同步 dram(sdram)、双数据率sdram(ddrsdram)、增强型sdram (esdram)、同步链路(synchlink)dram(sldram)、存储器总线(rambus) 直接ram(rdram)、直接存储器总线动态ram(drdram)、以及存储器总线动态ram(rdram)等。
[0147]
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
[0148]
以上所述实施例仅表达了本技术的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本技术构思的前提下,还可以做出若干变形和改进,这些都属于本技术的保护
范围。因此,本技术专利的保护范围应以所附权利要求为准。

技术特征:
1.一种基于数据增强的逐像素q值估计离线强化学习方法,其特征在于,所述方法包括:根据预先获取的智能体视觉控制的离线数据集,从所述离线数据集中采样小批量的原始输入观测;通过常见的图像变换算法对所述小批量的原始输入观测进行数据增强;每一小批量的数据由batch个数据组(s,a,r,s

)组成,其中batch为小批量数据的数量,s代表当前时刻图像堆栈,s

代表下一时刻的图像堆栈,a表示当前时刻智能体采取的动作,r表示智能体在当前时刻s采取动作a后得到的环境奖励反馈;每个所述图像堆栈中包含若干个连续帧堆叠的原始观测图像;图像堆栈s和s

分别进行k次和m次数据增强后,分别得到第一扩充样本和第二扩充样本;所述第一扩充样本中包括k个输入观测,所述第二扩充样本中包括m个输入观测;每一图像堆栈内的图像变换算法的参数设置一致,各图像堆栈的图像变换算法参数随机设置;通过预设的q值估计网络得到所述第一扩充样本中k个输入观测的q值,将k个输入观测的平均q值作为对应原始输入观测的预测q值;所述q值估计网络为任意使用q值估计的离线强化学习网络;通过预设的q值目标网络得到所述第二扩充样本中m个输入观测的q值,根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q值;所述q值目标网络与所述q值估计网络使用相同网络架构;根据所述预测q值和所述目标q值,通过最小化mse均方误差损失函数对所述q值估计网络的参数进行更新;根据更新后的所述q值估计网络的参数对所述q值目标网络的参数进行软更新,直到达到预设时间步停止更新;以完成更新的q值估计网络为最终学习到的q值估计强化学习网络。2.根据权利要求1所述的方法,其特征在于,根据预先获取的智能体视觉控制的离线数据集,包括:获取训练好的学习网络作为专家策略或行为策略;根据所述专家策略或行为策略,确定输入观测对应的具有最高值的动作及相应奖励;由当前时刻图像堆栈s、对应的动作a,对应的奖励r及下一时刻的图像堆栈s

,构成一个离线数据元组(s,a,r,s

);多次采样,得到大量离线数据元组,构成智能体视觉控制的离线数据集。3.根据权利要求1所述的方法,其特征在于,通过常见的图像变换算法对所述小批量的原始输入观测进行数据增强,包括:将随机的图像变换算法应用于所述小批量的原始输入观测得到变换图像;在所述变换图像上使用双线性插值算法,给每边填充2个像素,得到填充图像;在所述填充图像上进行随机裁剪,得到与原图像相同规格的数据增强后的图像。4.根据权利要求1所述的方法,其特征在于,通过预设q值估计网络得到所述第一扩充样本中k个输入观测的q值,将k个输入观测的平均q值作为对应原始输入观测的预测q值,包括:通过预设的q值估计网络得到所述第一扩充样本中k个输入观测的q值;
计算k个输入观测的平均q值为:其中,i为所述原始输入观测的索引,θ为所述q值估计网络的参数,(s
i
,a
i
)为所述原始输入观测的状态-动作对,f(s,v)为通过调整v对所述原始输入观测进行变换的变换函数,其中,f保留了状态-动作对的q值不变,即q
θ
(s,a)=q
θ
(f(s,ν),a);将所述平均q值q
θ
(s
i
,a
i
)作为对应原始输入观测的预测q值。5.根据权利要求4所述的方法,其特征在于,通过预设的q值目标网络得到所述第二扩充样本中m个输入观测的q值,根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q值,包括:通过预设的q值目标网络得到所述第二扩充样本中m个输入观测的q值;根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q值为:其中,θ

为q值目标网络的参数,r
i
为所述原始输入观测的奖励值,γ∈(0,1]是折扣系数,λ为调节系数,为通过调整对所述原始输入观测进行变换的变换函数。6.根据权利要求5所述的方法,其特征在于,根据所述预测q值和所述目标q值,通过最小化mse均方误差损失函数对所述q值估计网络的参数进行更新,包括:根据所述预测q值和所述目标q值,以所述q值目标网络为目标网络,通过最小化mse均方误差损失函数对所述q值估计网络的参数进行更新:其中,n为小批量的批量大小,α是学习率,表示q值估计的期望。7.根据权利要求6所述的方法,其特征在于,根据更新后的所述q值估计网络的参数对所述q值目标网络的参数进行软更新,包括:根据更新后的所述q值估计网络的参数对所述q值目标网络的参数进行软更新:θ'

τθ+(1-τ)θ'其中,τ为更新系数。8.一种基于数据增强的逐像素q值估计离线强化学习装置,其特征在于,所述装置包括:原始输入观测获取模块,用于根据预先获取的智能体视觉控制的离线数据集,从所述离线数据集中采样小批量的原始输入观测;数据增强模块,用于通过常见的图像变换算法对所述小批量的原始输入观测进行数据增强;每一小批量的数据由batch个数据组(s,a,r,s

)组成,其中batch为小批量数据的数量,s代表当前时刻图像堆栈,s

代表下一时刻的图像堆栈,a表示当前时刻智能体采取的动
作,r表示智能体在当前时刻s采取动作a后得到的环境奖励反馈;每个所述图像堆栈中包含若干个连续帧堆叠的原始观测图像;图像堆栈s和s

分别进行k次和m次数据增强后,分别得到第一扩充样本和第二扩充样本;所述第一扩充样本中包括k个输入观测,所述第二扩充样本中包括m个输入观测;每一图像堆栈内的图像变换算法的参数设置一致,各图像堆栈的图像变换算法参数随机设置;预测q值计算模块,用于通过预设的q值估计网络得到所述第一扩充样本中k个输入观测的q值,将k个输入观测的平均q值作为对应原始输入观测的预测q值;所述q值估计网络为任意使用q值估计的离线强化学习网络;目标q值计算模块,用于通过预设的q值目标网络得到所述第二扩充样本中m个输入观测的q值,根据m个输入观测的q值,基于td-error计算对应原始输入观测的目标q值;所述q值目标网络与所述q值估计网络使用相同网络架构;q值估计网络参数更新模块,用于根据所述预测q值和所述目标q值,通过最小化mse均方误差损失函数对所述q值估计网络的参数进行更新;q值目标网络参数更新模块,用于根据更新后的所述q值估计网络的参数对所述q值目标网络的参数进行软更新,直到达到预设时间步停止更新;以完成更新的q值估计网络为最终学习到的q值估计强化学习网络。9.根据权利要求8所述的装置,其特征在于,所述数据增强模块还用于:将随机的图像变换算法应用于所述小批量的原始输入观测得到变换图像;在所述变换图像上使用双线性插值算法,给每边填充2个像素,得到填充图像;在所述填充图像上进行随机裁剪,得到与原图像相同规格的数据增强后的图像。10.根据权利要求8所述的装置,其特征在于,所述预测q值计算模块还用于:通过预设的q值估计网络得到所述第一扩充样本中k个输入观测的q值;计算k个输入观测的平均q值为:其中,i为所述原始输入观测的索引,θ为所述q值估计网络的价值函数的参数,(s
i
,a
i
)为所述原始输入观测的状态-动作对,f(s,v)为通过调整v对所述原始输入观测进行变换的变换函数,其中,f保留了状态-动作对的q值不变,即q
θ
(s,a)=q
θ
(f(s,ν),a);将所述平均q值q
θ
(s
i
,a
i
)作为对应原始输入观测的预测q值。

技术总结
本申请涉及一种基于数据增强的逐像素Q值估计离线强化学习方法和装置。所述方法包括:通过从离线数据集中采样小批量的原始输入观测,通过常见的图像变换算法对小批量的原始输入观测进行数据增强,并对输入观测的Q值进行正则化处理,最终训练得到用于决策的Q值网络。本发明通过使用标准增量方法扩大训练数据集,通过正则化输入观测的Q值避免高估离线数据集数据分布附近的数据,并显著提升了算法泛化性;将数据增强与基于像素观测的离线RL算法结合起来的方法,不需要对底层RL算法进行额外修改,使得该方法易于实现,并可扩展应用到其他基于Q值估计的离线RL算法,可扩展性强,实用性佳。佳。佳。


技术研发人员:张龙飞 冯旸赫 张驭龙 刘忠 黄金才 程光权 陈丽 梁星星 吴克宇 阳方杰
受保护的技术使用者:中国人民解放军国防科技大学
技术研发日:2022.07.15
技术公布日:2022/11/1
转载请注明原文地址: https://tieba.8miu.com/read-1338.html

最新回复(0)