pixel-based RL 算法逆襲,BAIR 提出將對(duì)比學(xué)習(xí)與 RL 相結(jié)合的算法,其 sample-efficiency 匹敵 state-based RL。此次研究的本質(zhì)在于回答一個(gè)問(wèn)
pixel-based RL 算法逆襲,BAIR 提出將對(duì)比學(xué)習(xí)與 RL 相結(jié)合的算法,其 sample-efficiency 匹敵 state-based RL。
此次研究的本質(zhì)在于回答一個(gè)問(wèn)題—使用圖像作為觀測(cè)值(pixel-based)的 RL 是否能夠和以坐標(biāo)狀態(tài)作為觀測(cè)值的 RL 一樣有效?傳統(tǒng)意義上,大家普遍認(rèn)為以圖像為觀測(cè)值的 RL 數(shù)據(jù)效率較低,通常需要一億個(gè)交互的 step 來(lái)解決 Atari 游戲那樣的基準(zhǔn)測(cè)試任務(wù)。
研究人員介紹了 CURL:一種用于強(qiáng)化學(xué)習(xí)的無(wú)監(jiān)督對(duì)比表征。CURL 使用對(duì)比學(xué)習(xí)的方式從原始像素中提取高階特征,并在提取的特征之上執(zhí)行異策略控制。在 DeepMind Control Suite 和 Atari Games 中的復(fù)雜任務(wù)上,CURL 優(yōu)于以前的 pixel-based 的方法(包括 model-based 和 model-free),在 100K 交互步驟基準(zhǔn)測(cè)試中,其性能分別提高了 2.8 倍以及 1.6 倍。在 DeepMind Control Suite 上,CURL 是第一個(gè)幾乎與基于狀態(tài)特征方法的 sample-efficiency 和性能所匹配的基于圖像的算法。
背景介紹
CURL 是將對(duì)比學(xué)習(xí)與 RL 相結(jié)合的通用框架。理論上,可以在 CURL pipeline 中使用任一 RL 算法,無(wú)論是同策略還是異策略。對(duì)于連續(xù)控制基準(zhǔn)而言(DM Control),研究團(tuán)隊(duì)使用了較為熟知的 Soft Actor-Critic(SAC)(Haarnoja et al., 2018) ;而對(duì)于離散控制基準(zhǔn)(Atari),研究團(tuán)隊(duì)使用了 Rainbow DQN(Hessel et al., 2017))。下面,我們簡(jiǎn)要回顧一下 SAC,Rainbow DQN 以及對(duì)比學(xué)習(xí)。
Soft Actor Critic
SAC 是一種異策略 RL 算法,它優(yōu)化了隨機(jī)策略,以最大化預(yù)期的軌跡回報(bào)。像其他 SOTA 端到端的 RL 算法一樣,SAC 在從狀態(tài)觀察中解決任務(wù)時(shí)非常有效,但卻無(wú)法從像素中學(xué)習(xí)有效的策略。
Rainbow
最好將 Rainbow DQN(Hessel et al., 2017)總結(jié)為在原來(lái)應(yīng)用 Nature DQN 之上的多項(xiàng)改進(jìn)(Mnih et al., 2015)。具體來(lái)說(shuō),深度 Q 網(wǎng)絡(luò)(DQN)(Mnih et al., 2015)將異策略算法 Q-Learning 與卷積神經(jīng)網(wǎng)絡(luò)作為函數(shù)逼近器相結(jié)合,將原始像素映射到動(dòng)作價(jià)值函數(shù)里。
除此之外,價(jià)值分布強(qiáng)化學(xué)習(xí)(Bellemare et al., 2017)提出了一種通過(guò) C51 算法預(yù)測(cè)可能值函數(shù) bin 上的分布技術(shù)。Rainbow DQN 將上述所有技術(shù)組合在單一的異策略算法中,用以實(shí)現(xiàn) Atari 基準(zhǔn)的最新 sample efficiency。此外,Rainbow 還使用了多步回報(bào)(Sutton et al.,1998)。
對(duì)比學(xué)習(xí)
CURL 的關(guān)鍵部分是使用對(duì)比無(wú)監(jiān)督學(xué)習(xí)來(lái)學(xué)習(xí)高維數(shù)據(jù)的豐富表示的能力。對(duì)比學(xué)習(xí)可以理解為可區(qū)分的字典查找任務(wù)。給定一個(gè)查詢 q、鍵 K= {k_0, k_1, . . . } 以及一個(gè)明確的 K(關(guān)于 q)P(K) = ({k+}, K \ {k+}) 分區(qū),對(duì)比學(xué)習(xí)的目標(biāo)是確保 q 與 k +的匹配程度比 K \ {k +} 中的任何的鍵都更大。在對(duì)比學(xué)習(xí)中,q,K,k +和 K \ {k +} 也分別稱為錨點(diǎn)(anchor),目標(biāo)(targets),正樣本(positive), 負(fù)樣本(negatives)。
CURL 具體實(shí)現(xiàn)
CURL 通過(guò)將訓(xùn)練對(duì)比目標(biāo)作為批更新時(shí)的輔助損失函數(shù),在最小程度上改變基礎(chǔ) RL 算法。在實(shí)驗(yàn)中,研究者將 CURL 與兩個(gè)無(wú)模型 RL 算法一同訓(xùn)練——SAC 用于 DMControl 實(shí)驗(yàn),Rainbow DQN 用于 Atari 實(shí)驗(yàn)。
總體框架概述
CURL 使用的實(shí)例判別方法(instance discrimination)類似于 SimCLR、MoC 和 CPC。大多數(shù)深度強(qiáng)化學(xué)習(xí)框架采用一系列堆疊在一起的圖像作為輸入。因此,算法在多個(gè)堆疊的幀中進(jìn)行實(shí)例判別,而不是單一的圖像實(shí)例。
研究者發(fā)現(xiàn),使用類似于 MoCo 的動(dòng)量編碼流程(momentum encoding)來(lái)處理目標(biāo),在 RL 中性能較好。最后,研究者使用一個(gè)類似于 CPC 中的雙線性內(nèi)積來(lái)處理 InfoNCE score 方程,研究者發(fā)現(xiàn)效果比 MoCo 和 SimCLR 中的單位范數(shù)向量積(unit norm vector products)要好。對(duì)比表征和 RL 算法一同進(jìn)行訓(xùn)練,同時(shí)從對(duì)比目標(biāo)和 Q 函數(shù)中獲得梯度??傮w框架如下圖所示。
判別目標(biāo)
選擇關(guān)于一個(gè)錨點(diǎn)的正、負(fù)樣本是對(duì)比表征學(xué)習(xí)的其中一個(gè)關(guān)鍵組成部分。
不同于在同一張圖像上的 image-patches,判別變換后的圖像實(shí)例優(yōu)化帶有 InfoNCE 損失項(xiàng)的簡(jiǎn)化實(shí)例判別目標(biāo)函數(shù),并需要最小化對(duì)結(jié)構(gòu)的調(diào)整。在 RL 設(shè)定下,選擇更簡(jiǎn)化判別目標(biāo)的理由主要有如下兩點(diǎn):
鑒于 RL 算法十分脆弱,復(fù)雜的判別目標(biāo)可能導(dǎo)致 RL 目標(biāo)不穩(wěn)定。
RL 算法在動(dòng)態(tài)生成的數(shù)據(jù)集上進(jìn)行訓(xùn)練,復(fù)雜的判別目標(biāo)可能會(huì)顯著增加訓(xùn)練所需時(shí)間。
因此,CURL 使用實(shí)例判別而不是 patch 判別。我們可將類似于 SimCLR 和 MoCo 這樣的對(duì)比實(shí)例判別設(shè)置,看做最大化一張圖像與其對(duì)應(yīng)增廣版本之間的共同信息。
查詢-鍵值對(duì)的生成
類似于在圖像設(shè)定下的實(shí)例判別,錨點(diǎn)和正觀測(cè)值是來(lái)自同一幅圖像的兩個(gè)不同增廣值,而負(fù)觀測(cè)值則來(lái)源于其他圖像。CURL 主要依靠隨機(jī)裁切數(shù)據(jù)增廣方法,從原始渲染圖像中隨機(jī)裁切一個(gè)正方形的 patch。
研究者在批數(shù)據(jù)上使用隨機(jī)數(shù)據(jù)增廣,但在同一堆幀之間保持一致,以保留觀測(cè)值時(shí)間結(jié)構(gòu)的信息。數(shù)據(jù)增廣流程如圖 3 所示。
相似度量
區(qū)分目標(biāo)中的另一個(gè)決定因素是用于測(cè)量查詢鍵對(duì)之間的內(nèi)部乘積。CURL 采用雙線性內(nèi)積 sim(q,k)= q^TW_k,其中 W 是學(xué)習(xí)的參數(shù)矩陣。研究團(tuán)隊(duì)發(fā)現(xiàn)這種相似性度量的性能優(yōu)于最近在計(jì)算機(jī)視覺(如 MoCo 和 SimCLR)中最新的對(duì)比學(xué)習(xí)方法中使用的標(biāo)準(zhǔn)化點(diǎn)積。
動(dòng)量目標(biāo)編碼
在 CURL 中使用對(duì)比學(xué)習(xí)的目標(biāo)是訓(xùn)練從高維像素中能映射到更多語(yǔ)義隱狀態(tài)的編碼器。InfoNCE 是一種無(wú)監(jiān)督的損失,它通過(guò)學(xué)習(xí)編碼器 f_q 和 f_k 將原始錨點(diǎn)(查詢)x_q 和目標(biāo)(關(guān)鍵字)x_k 映射到潛在值 q = f_q(x_q) 和 k = f_k(x_k) 上,在此團(tuán)隊(duì)?wèi)?yīng)用相似點(diǎn)積。通常在錨點(diǎn)和目標(biāo)映射之間共享相同的編碼器,即 f_q = f_k。
CURL 將幀-堆棧實(shí)例的識(shí)別與目標(biāo)的動(dòng)量編碼結(jié)合在一起,同時(shí) RL 是在編碼器特征之上執(zhí)行的。
CURL 對(duì)比學(xué)習(xí)偽代碼(PyTorch 風(fēng)格)
實(shí)驗(yàn)
研究者評(píng)估(i)sample-efficiency,方法具體為測(cè)量表現(xiàn)最佳的基線需要多少個(gè)交互步驟才能與 100k 交互步驟的 CURL 性能相匹配,以及(ii)通過(guò)測(cè)量 CURL 取得的周期回報(bào)值與最佳表現(xiàn)基線的比例來(lái)對(duì)性能層面的 100k 步驟進(jìn)行衡量。換句話說(shuō),當(dāng)談到數(shù)據(jù)或 sample-efficiency 時(shí),其實(shí)指的是(i),而當(dāng)談起性能時(shí)則指的是(ii)。
DMControl
在 DMControl 實(shí)驗(yàn)中的主要發(fā)現(xiàn):
CURL 是我們?cè)诿總€(gè) DMControl 環(huán)境上進(jìn)行基準(zhǔn)測(cè)試的 SOTA ImageBased RL 算法,用于根據(jù)現(xiàn)有的 Image-based 的基準(zhǔn)進(jìn)行采樣效率測(cè)試。在 DMControl100k 上,CURL 的性能比 Dreamer(Hafner 等人,2019)高 2.8 倍,這是一種領(lǐng)先的 model-based 的方法,并且數(shù)據(jù)效率高 9.9 倍。
從圖 7 所示的大多數(shù) 16 種 DMControl 環(huán)境中的狀態(tài)開始,僅靠像素操作的 CURL 幾乎可以進(jìn)行匹配(有時(shí)甚至超過(guò))SAC 的采樣效率。它是基于 model-based,model-free,有輔助任務(wù)或者是沒(méi)有輔助任務(wù)。
在 50 萬(wàn)步之內(nèi),CURL 解決了 16 個(gè) DMControl 實(shí)驗(yàn)中的大多數(shù)(收斂到接近 1000 的最佳分?jǐn)?shù))。它在短短 10 萬(wàn)步的時(shí)間內(nèi)就具有與 SOTA 相似性能的競(jìng)爭(zhēng)力,并且大大優(yōu)于該方案中的其他方法。
表 1. 在 500k(DMControl500k)和 100k(DMControl100k)環(huán)境步長(zhǎng)基準(zhǔn)下,CURL 和 DMControl 基準(zhǔn)上獲得的基線得分。
圖 4. 相對(duì)于 SLAC、PlaNet、Pixel SAC 和 State SAC 基線,平均 10 個(gè) seeds 的 CURL 耦合 SAC 性能。圖 6. 要獲得與 CURL 在 100k 訓(xùn)練步驟中所得分相同的分?jǐn)?shù),需要先行采用領(lǐng)先的 pixel-based 方法 Dreamer 的步驟數(shù)。
圖 7. 將 CURL 與 state-based 的 SAC 進(jìn)行比較,在 16 個(gè)所選 DMControl 環(huán)境中的每個(gè)環(huán)境上運(yùn)行 2 個(gè) seeds。
Atari
在 Atari 實(shí)驗(yàn)中的主要發(fā)現(xiàn):
就大多數(shù) 26 項(xiàng) Atari100k 實(shí)驗(yàn)的數(shù)據(jù)效率而言,CURL 是 SOTA PixelBased RL 算法。平均而言,在 Atari100k 上,CURL 的性能比 SimPLe 高 1.6 倍,而 Efficient Rainbow DQN 則高 2.5 倍。
CURL 達(dá)到 24%的人類標(biāo)準(zhǔn)化分?jǐn)?shù)(HNS),而 SimPLe 和 Efficient Rainbow DQN 分別達(dá)到 13.5%和 14.7%。CURL,SimPLe 和 Efficient Rainbow DQN 的平均 HNS 分別為 37.3%,39%和 23.8%。
CURL 在三款游戲 JamesBond(98.4%HNS),F(xiàn)reeway(94.2%HNS)和 Road Runner(86.5%HNS)上幾乎可以與人類的效率相提并論,這在所有 pixel-based 的 RL 算法中均屬首例。
表 2. 通過(guò) CURL 和以 10 萬(wàn)個(gè)時(shí)間步長(zhǎng)(Atari100k)為標(biāo)準(zhǔn)所獲得的分?jǐn)?shù)。CURL 在 26 個(gè)環(huán)境中的 14 個(gè)環(huán)境中實(shí)現(xiàn)了 SOTA。
安裝
所有相關(guān)項(xiàng)都在 conda_env.yml 文件中。它們可以手動(dòng)安裝,也可以使用以下命令安裝:
conda env create -f conda_env.yml
使用說(shuō)明
要從基于圖像的觀察中訓(xùn)練 CURL agent 完成 cartpole swingup 任務(wù),請(qǐng)從該目錄的根目錄運(yùn)行 bash script/run.sh。run.sh 文件包含以下命令,也可以對(duì)其進(jìn)行修改以嘗試不同的環(huán)境/超參數(shù)。
CUDA_VISIBLE_DEVICES=0 python train.py \
--domain_name cartpole \
--task_name swingup \
--encoder_type pixel \
--action_repeat 8 \
--save_tb --pre_transform_image_size 100 --image_size 84 \
--work_dir ./tmp \
--agent curl_sac --frame_stack 3 \
--seed -1 --critic_lr 1e-3 --actor_lr 1e-3 --eval_freq 10000 --batch_size 128 --num_train_steps 1000000
在控制臺(tái)中,應(yīng)該看到如下所示的輸出:
| train | E: 221 | S: 28000 | D: 18.1 s | R: 785.2634 | BR: 3.8815 | A_LOSS: -305.7328 | CR_LOSS: 190.9854 | CU_LOSS: 0.0000
| train | E: 225 | S: 28500 | D: 18.6 s | R: 832.4937 | BR: 3.9644 | A_LOSS: -308.7789 | CR_LOSS: 126.0638 | CU_LOSS: 0.0000
| train | E: 229 | S: 29000 | D: 18.8 s | R: 683.6702 | BR: 3.7384 | A_LOSS: -311.3941 | CR_LOSS: 140.2573 | CU_LOSS: 0.0000
| train | E: 233 | S: 29500 | D: 19.6 s | R: 838.0947 | BR: 3.7254 | A_LOSS: -316.9415 | CR_LOSS: 136.5304 | CU_LOSS: 0.0000
cartpole swing up 的最高分?jǐn)?shù)約為 845 分。而且,CURL 如何以小于 50k 的步長(zhǎng)解決 visual cartpole。根據(jù)使用者的 GPU 不同而定,大約需要一個(gè)小時(shí)的訓(xùn)練。同時(shí)作為參考,最新的端到端方法 D4PG 需要 50M 的 timesteps 來(lái)解決相同的問(wèn)題。
Log abbreviation mapping:
train - training episode
E - total number of episodes
S - total number of environment steps
D - duration in seconds to train 1 episode
R - mean episode reward
BR - average reward of sampled batch
A_LOSS - average loss of actor
CR_LOSS - average loss of critic
CU_LOSS - average loss of the CURL encoder
與運(yùn)行相關(guān)的所有數(shù)據(jù)都存儲(chǔ)在指定的 working_dir 中。若要啟用模型或視頻保存,請(qǐng)使用--save_model 或--save_video。而對(duì)于所有可用的標(biāo)志,需要檢查 train.py。使用 tensorboard 運(yùn)行來(lái)進(jìn)行可視化:
tensorboard --logdir log --port 6006
同時(shí)在瀏覽器中轉(zhuǎn)到 localhost:6006。如果運(yùn)行異常,可以嘗試使用 ssh 進(jìn)行端口轉(zhuǎn)發(fā)。
對(duì)于使用 GPU 加速渲染,確保在計(jì)算機(jī)上安裝了 EGL 并設(shè)置了 export MUJOCO_GL = egl。
關(guān)鍵詞: BAIR