機器學習中的Optimal Transport及相關問題:(二)計算方法

最近有不少小夥伴諮詢我關於最優傳輸理論(Optimal Transport)相關的計算方法,恰好我最近也在寫畢業論文,作為博士畢業論文的核心課題,那我就來簡單科普一下它跳坑的正確姿勢好了。相信大家可能都知道一些基礎的東西,比如它的定義以及可能的應用。從某種意義來說,這次OT在機器學習界的小高潮跟以前Kernel Method在機器學習界的發展非常相似,數學上可以推導出一些漂亮的性質,實踐上又能找到一些落地的場景(灌水利器)。

這篇文章是這個系列的第二篇(打算寫三到四篇,其他幾篇,包括第一篇會陸續放出)。第一篇會做一些初步的介紹,聊聊問題的背景和八卦,有興趣的同學可以先看看我以前的知乎回答:知乎用戶:分布的相似度(距離)用什麼模型比較好? 第二篇也就是這篇會著重介紹它目前領先的計算方法,第三、四篇可能會談談它在機器學習中的應用。

一般來說它的問題是這樣的:

作為一個非常經典的線性規劃問題,當前已有的線性規劃演算法已經能相對快速的對小規模問題求解了。但是如果要用在機器學習領域,依然還有兩個主要的計算問題:

  • 如果 m_1,m_2 很大怎麼辦?LP求解問題的計算規模是 O(m_1m_2(m_1+m_2)log (max{m_1,m_2})) [Orlin, 1993]
  • 如果不止一個,而是有大規模的不同OT問題要同時求解怎麼辦?

這個時候需要有一些近似的計算方法,能夠在 O(n^2/varepsilon^q) 的時間計算出 varepsilon -guarantee的解。

Entropic Regularization and Sinkhorn algorithm

無疑目前最流行的一個方法就是用entropic regularization把問題變成一個strongly convex的近似,並使用Sinkhorn演算法求解。簡單來說,就是求解如下問題:

這裡 H(Z) 是entropy function。[Cuturi, 2013]提出用Sinkhorn iteration來求解如上問題:

準確來說,Sinkhorn在實現上有兩種策略,一個是在log space上迭代(也就是上圖所示),一個是直接迭代 u=exp(mathbf x),v=exp(mathbf y) 。一般來說後一種實現出來的計算效率更高一些。那麼Sinkhorn演算法有什麼理論保證呢?最近的一篇文章中,[Altschuler et al. 2017]給出了一個不錯的結果,如果 A=exp(eta M) ,並且 U_{r,c}=Pi(mathbf p, mathbf q) ,Sinkhorn可以在 O((varepsilon)^2)(log n + eta |M|_{infty}) 的迭代次數得到一個近似解 hat Z 使得

langle hat Z, M
angle le min_{Zin Pi(mathbf p, mathbf q)}langle Z, M
angle + frac{2 log n}{eta} + 4varepsilon|M|_{infty},

並且

|r(hat Z) -mathbf p|_1 +|c(hat Z) - mathbf q|_1<varepsilon.

[Altschuler et al. 2017]進一步指出,給定任意的 varepsilon 只要選取合適的 eta,varepsilon ,我們可以在 O(m_1m_2/varepsilon^3) 的時間(near-linear)產生 varepsilon -guarantee in objective, varepsilon^2 -guarantee in constraints的解。然而Sinkhorn存在兩個主要問題使得它在現實中很難得到這樣的性能:

  • eta^{-1} 非常小的時候,演算法迭代若干次(遠少於理論bound需要的迭代規模)後就很容易超出浮點精度。
  • eta^{-1} 比較大,迭代次數相對少的情況下,Sinkhorn解雖然線性收斂到一個smooth的approximation (Eq. (3.5)),但是對原問題目標函數(Eq. (3.1))的近似效果就非常差了。

這兩個問題在我的TSP文章中有過比較仔細的討論 [Ye et al. 2017]。

相關文獻:

Cuturi, Marco. "Sinkhorn distances: Lightspeed computation of optimal transport." Advances in neural information processing systems. 2013.

Altschuler, Jason, Jonathan Weed, and Philippe Rigollet. "Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration." Advances in Neural Information Processing Systems. 2017.

Bregman ADMM

近似求解OT還有一個不太為人了解的方法,這個方法據我所知是[Wang and Banerjee, 2014]最早提出的。然而因為文章本身並不是主要服務於OT的literature,而且證明的理論結果比較general(複雜),很少被人提及。方法的基本想法類似經典的ADMM,先把原問題寫成等價的如下問題:

min_{substack{r(Z_1)=mathbf p\ c(Z_2) = mathbf q}} langle Z_1, M
angle mbox{ s.t. } Z_1 = Z_2.

然後給最後一個等式約束做method of multiplier:

egin{eqnarray*} Z_1&:=&mbox{argmin}_{r(Z_1) = mathbf p} langle Z_1, M 
angle + langle Lambda, Z_1 
angle + underbrace{
ho cdot mbox{KL}(Z_1, Z_2)}_{	ext{replace $|cdot|^2$ with $B_{Phi}(cdot,cdot)$}}\ Z_2&:=&mbox{argmin}_{c(Z_2)=mathbf q} -langle Lambda, Z_2
angle + 
ho cdot mbox{KL}(Z_2, Z_1)\ Lambda &:= &Lambda + 
ho (Z_1 - Z_2) end{eqnarray*}

於是得到一下演算法:

這個方法也是有理論bound的,具體來說我們定義D(W^ast ,W^t) = mbox{KL}(Z^ast, Z_2^t) + dfrac{1}{
ho^2} |Lambda^ast - Lambda^t|^2那麼我們有

langle ar{Z}_1^T, M
angle - langle Z^ast, M
angle le frac{
ho mbox{KL}(Z^ast, Z_2^0)}{T},

以及

|ar{Z}_1^T - ar{Z}_2^T|_1 le sqrt{m_1m_2} |ar{Z}_1^T - ar{Z}_2^T|_2le sqrt{dfrac{2D(W^ast, W^0)m_1m_2}{T}}

其中  ar{Z_j}^T=frac{1}{T}sum_{t=1}^T Z_j^t, j=1,2

相關文獻:

Wang, Huahua, and Arindam Banerjee. "Bregman alternating direction method of multipliers." Advances in Neural Information Processing Systems. 2014.

兩種方法的比較

可以看到給定相同的迭代次數,我們可以有效對比兩個方法的收斂rate

下面的這個表格簡單概括了這個結果

理論上來說,在 sqrt{m_1m_2} ll T 的情況下,只要在B-ADMM中選擇合適的 
ho就可以得到一個收斂更快的解用來近似原始的目標函數(Eq. (1.3)),但是這個解相比Sinkhorn的解更不容易滿足constraints 。[Ye et al. 2017]詳細比較了這兩個方法,理論解釋和該篇文章中的實驗結果也是匹配的。值得一提的是,在大多數機器學習的應用中,嚴格滿足兩個marginal constraint並不是必須的,但有一個合理的方法近似目標函數卻是十分必要的。 下面這個圖,在我過去的talk中貼了很多次,是一個直觀比較收斂特性的toy example。

除了Sinkhorn和B-ADMM,還有一些別的近似方法,比如我去年的ICML文章用Sampling的辦法來近似求解OT,著重處理OT優化中warm-start的情況。今年也有ICML的submission用Proximal Point Method來求解OT。

相關代碼:

bobye/OT_demo

相關文獻:

Ye, Jianbo, et al. "Fast discrete distribution clustering using Wasserstein barycenter with sparse support." IEEE Transactions on Signal Processing 65.9 (2017): 2317-2332.

Ye, Jianbo, James Z. Wang, and Jia Li. "A Simulated Annealing Based Inexact Oracle for Wasserstein Loss Minimization." International Conference on Machine Learning. 2017.

Xie, Yujia, et al. "A Fast Proximal Point Method for Wasserstein Distance." arXiv preprint arXiv:1802.04307(2018).

從OT到Wasserstein barycenter

相比求解單個OT問題,Wasserstein barycenter(WBC)把多個OT問題couple在一起,這種情況在把Wasserstein distance當作loss function的機器學習問題非常常見。WBC的問題簡單來說就是給一組分布,求解它們的中心:

 min_{P} frac{1}{N} sum_{k=1}^{N} W^2 ( P,P^{( k )} )

在Sinkhorn或者B-ADMM的框架下,這個經典問題都可以得到有效求解。在Sinkhorn框架下,[Benamou et al., 2015]提出iterative Bregman projection來求解WBC,在B-ADMM框架下[Ye et al. 2017]提出的WBC辦法可以用來解決在Wasserstein space類似K-means的問題。

相關代碼:

bobye/WBC_Matlab

相關文獻:

Ye, Jianbo, et al. "Fast discrete distribution clustering using Wasserstein barycenter with sparse support." IEEE Transactions on Signal Processing 65.9 (2017): 2317-2332.

Benamou, Jean-David, et al. "Iterative Bregman projections for regularized transportation problems." SIAM Journal on Scientific Computing 37.2 (2015): A1111-A1138.


推薦閱讀:

【學界】關於KKT條件的深入探討
變の貝葉斯
【學界/編碼】凸優化演算法 I: 內點法(interior point method)求解線性規劃問題

TAG:機器學習 | 數值計算 | 凸優化 |