JAXによるスケーラブルな機械学習

OGP

はじめに

こんにちは、ZOZO NEXT ZOZO ResearchのSai Htaung Khamです。ZOZO NEXTは、ファッション領域におけるユーザーの課題を想像しテクノロジーの力で解決すること、より多くの人がファッションを楽しめる世界の創造を目指す企業です。

ZOZO NEXTでは多くのアルゴリズムを研究開発しており、その中でJAXというライブラリを使用しています。JAXは高性能な機械学習のために設計されたPythonのライブラリです。NumPyに似ていますが、より強力なライブラリであると考えることができます。NumPyとは異なり、JAXはマルチGPU、マルチTPU、そして機械学習の研究に非常に有用な自動微分(Autograd)をサポートしています。

JAXはNumPyのAPIのほとんどをミラーリングしているので、NumPyライブラリに慣れている人なら非常に導入しやすいです。Autogradを使えば、Pythonのネイティブ関数とNumPyの関数を自動的に微分できます。JAXの詳細な機能については、JAXの公式GitHubリポジトリを参照してください。

そもそも、なぜJAXなのか?

機械学習アルゴリズムを構築する際、多くのMLエンジニアはTensorflowやPyTorchといった信頼性の高いMLフレームワークを利用することでしょう。成熟したMLフレームワークには成熟したエコシステムがあり、本番環境への統合や保守が容易になるため、良い決断です。当研究所では、これらのフレームワークを用いて実装された多くのアルゴリズムを持っています。

しかし、いくつかのアルゴリズムはNumPyライブラリを用いて、純粋なPythonで実装されています。その中には例えば、研究者やMLエンジニアが社内用に設計した埋め込みアルゴリズムがあります。埋め込みアルゴリズムは類似商品を効率よく抽出できるため、商品推薦などに有用です。実装がPythonであるため、このアルゴリズムは計算の実行時間にボトルネックがあります。お気づきのように、フレームワークを使用しない場合、パラメータの更新やモデルのダンプ等も自前で実装する必要があります。そのため、新しいアイデアをすぐに試すことが難しく、なかなか前に進めません。また、ライブラリや学習プロセスもCPUデバイスに限定されるため、拡張性がありません。共有メモリアーキテクチャを利用してマルチプロセスでアルゴリズムを実行できましたが、GPUやTPUなどの複数のホストやデバイスで実行し、垂直方向・水平方向にスケールできる状態が望ましいです。

そこで、拡張性・保守性の高い別のフレームワークにプログラムを移植する方法を検討した結果、以下のような特徴を持つJAXを採用しました。

  1. Single Program Multiple Dataアーキテクチャによる水平方向のスケーラビリティ
  2. NumPyのAPIをミラーリング
  3. Pythonに対応
  4. Autogradのサポート
  5. エコシステムまでオープンソース化されている(FlaxやHaikuなど)

特に(2)の性質によってNumPyで書かれたアルゴリズムを効率よく移植できるという点が、既存の他のフレームワークにはない利点でした。

本記事を読むことで分かること

本記事では、実世界のデータを使った機械学習を、JAXライブラリで実行する方法について説明します。通常、機械学習の理論を学び問題を解くときには、理解を深めるために小さなデータを使用します。しかし、実世界のデータに応用するとデータ量、モデルを格納するメモリサイズ、学習と評価のスケーラビリティなど多くの困難に直面することになります。ありがたいことに、現代のクラウドコンピューティングの革命と価格設定により、スケーラブルな機械学習は誰でも利用できるようになりました。

典型的な機械学習パイプライン

典型的な機械学習プロジェクトはデータの準備からモデルのサービングまで多くのステージで構成されますが、本記事で取り扱うのはデータの準備とモデルの学習に当たる部分です。特にJAXライブラリのパフォーマンスと、クラウドコンピューティング上でのスケーラビリティを実現する方法について説明します。

データって本当に大きいの? いつ、どこで、どうやって処理するの?

データと質の高いデータ変換が機械学習プロジェクトの成功の中心であることは、すべてのMLエンジニアが理解していることです。現実の機械学習プロジェクトでは、1台のマシンでETL(抽出、変換、ロード)プロセスを行えるような量のデータを扱うことは稀です。当研究所では、Google CloudやAWSなどのクラウドコンピューティングリソースに広く依存しており、通常、クラウドストレージやクラウドデータウェアハウスを使用してデータを管理しています。

データストレージとアルゴリズムをホストするマシンの統合

ボトルネックに要注意!

クラウドストレージは、1台のマシンに収まりきらない大量のデータを保存するのにとても役立ちます。しかし、モデルの学習に利用するためには、ストレージからデータを読み出す効率的な方法を見つける必要があります。多くのMLエンジニアがGPUデバイスを使った学習中に遭遇する問題の1つは、GPUデバイスが十分に活用し切れず、学習プロセスに必要以上の時間がかかってしまうことです。次のTensorFlowモデルのプロファイリング結果をご覧ください。

モデルのプロファイリング 参照:[モデルのプロファイリング]

よく観察すると、ディスクからデータを取得している間、GPUデバイスはほとんどの時間、アイドル状態であることに気づかれると思います。一般的には、学習中はGPUデバイスをビジー状態にしたいものです。これは、データ入力パイプラインにボトルネックがあることを示しています。

TF Dataってヒーローなの?

データ入力パイプラインのボトルネックを解消するために、TF DataというTensorFlowが提供する便利なツールを利用することにします。従来の方法では、下図のようにディスクからデータを順次読み込んでいました。下図のMapは、正規化、画像補強などのデータの変換処理です。

モデルへのデータの順次取り込み 参照:[モデルへのデータの順次取り込み]

しかし、この方法では学習処理にデータ転送待ちが発生し、GPUデバイスがアイドル状態になってしまうというボトルネックが発生しています。そこで下図のように読み込みとデータ変換を並列に行うことで、学習の待ち時間が少なくなります。

TF Data Pipelineによる効率的なデータ変換 参照:[TF Data Pipelineで効率的なデータ変換]

TFデータパイプラインのコンポーネントは再利用可能です。トレーニングやサービングフェーズに適用できます。TF DataライブラリはホストCPUを利用してデータを並列に処理しているので、CPUの性能が高ければ高いほど、データの読み込みや前処理が高速になることを念頭に置いておくことが重要です。

データ前処理パイプラインとして、Apache BeamやTFX Transformを使用する方法もありますが、今回は説明しません。本記事では、TF DataとJAXを使用して、スケーラブルな機械学習を共有します。

処理を高速化してみよう!

効果的なデータ前処理パイプラインを手に入れたことで、モデルの学習と評価のステップに移行します。JAXの便利なライブラリにvmapとpmapがあります。本記事では、vmapとpmapを使用してマルチGPUデバイスでの学習処理を高速化します。

#vmapによるauto-vectorization
import numpy as np
import jax.numpy as jnp
import jax

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

x = np.arange(5)
w = np.array([3., 1., 3.])
batch_size = 10
xs = np.arange(5 * batch_size).reshape(-1, 5)
ws = np.stack([w] * batch_size)
print(f"The shape of the x and w : {xs.shape, ws.shape}")

print("Process each sample.")
for sample in xs:
    print(convolve(sample, w))

print("Auto-vectorization with vmap:")
print(jax.vmap(convolve)(xs, ws))
#vmap処理とサンプル単位処理の比較結果
The shape of the x and w : ((10, 5), (10, 3))
Process each sample.
[ 7. 14. 21.]
[42. 49. 56.]
[77. 84. 91.]
[112. 119. 126.]
[147. 154. 161.]
[182. 189. 196.]
[217. 224. 231.]
[252. 259. 266.]
[287. 294. 301.]
[322. 329. 336.]
Auto-vectorization with vmap:
[[  7.  14.  21.]
 [ 42.  49.  56.]
 [ 77.  84.  91.]
 [112. 119. 126.]
 [147. 154. 161.]
 [182. 189. 196.]
 [217. 224. 231.]
 [252. 259. 266.]
 [287. 294. 301.]
 [322. 329. 336.]]

まずはvmapに関して説明します。vmapはコードを変更することなく関数をベクトル化(auto-vectorization)するものです。auto-vectorizationにより、vmap APIで関数をラップする以外にコードを変更することなく処理を高速化できます。これは、特にバッチ処理の際に非常に便利です。vmapの機能はまだまだあるので、以下のリンクから確認してください。

jax.readthedocs.io

pmapの使い方は、vmapとよく似ています。しかし、pmapはMPIのようなCollective operationを提供し、プログラムが複数のデバイス上で通信しデバイスをまたいで合計や平均などの演算「MapReduce」を実行できます。このAPIにより、プログラムはスケールアウトできます。

#マルチデバイスでpmapを適用する
@partial(jax.pmap, axis_name="num_devices")
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
    loss, grads = jax.value_and_grad(loss_func)(params, x, y)

    grads = jax.lax.pmean(grads, axis_name="num_devices")
    loss = jax.lax.pmean(loss, axis_name="num_devices")

    new_params = jax.tree_multimap(
        lambda param, g: param - g * step_size, params, grads
    )

    return new_params, loss

上記のコードサンプルでは、異なるデバイスでloss_func関数から返された結果に対してCollective meanを実行し、パラメータを更新しています。このコードブロックは、アクセラレータデバイスの数を気にすることなく、どのマシン上でも実行できます。バックグラウンドでJAXによって自動的にスケールアウトし、管理されます。ただし、アクセラレータデバイスの数に応じて、デバイスの次元を一致させる必要があります。デバイスの次元とは、デバイス間でデータを均等に分割するために使用される仮想的なメトリックディメンジョンのことです。例えば、8台のデバイスがある場合、同時に処理するサンプルは少なくとも8個必要です。

環境設定

本記事では、JAXライブラリを用いて2つのデータセットを検証します。1つ目はMNISTの手書き数字データセット、2つ目はカスタムデータセットです。まずはMNIST手書き数字データセットのためのシンプルな多層パーセプトロン(MLP)を構築しました。

以下の図は使用したインフラ設定です。

使用したインフラ設定

MLPのハイパーパラメータ設定です。

MLPのハイパーパラメータ

以下の図はアクセラレータ毎の平均実行の時間、異なるアクセラレータでのアルゴリズム実行時間の比較です。

アクセラレータ毎の平均実行時間

JAXが提供するpmap APIを使えば、簡単に複数のデバイスでモデルを実行し、学習と配信のためにスケールアウトさせることができました。CPUでは各エポックに約3.34秒かかるのに対し、4GPUでは1.09秒であることが確認されました。この図は、より多くの並列処理を行うほど、この特定のアルゴリズムの実行時間が短縮することを示しています。

以下の4GPUでの学習と各エポックでの実行時間図は、4つのGPUアクセラレータを用いた場合の、各エポックにおけるモデル学習と実行時間の性能を示しています。

4GPUで学習と各エポックの実行時間 4GPUで学習と各エポックの実行時間

上の図から、モデルはトレーニングデータセットでうまく学習し、バリデーションデータセットでもうまくいっていることが確認できました。また、学習処理は4つのGPUデバイス全てに均等に分散されています。最初のエポックを終えるのに約1.8秒、その後のエポックでは約1.09秒かかっています(左下図)。最初のエポックでは、クラウドストレージから画像をリモートで読み込んで、パイプラインデータ変換に応じた前処理を行う必要があります。その後、パイプラインのキャッシュ機能を使ってデータをローカルにキャッシュし、次のエポックに備えることで、実行時間を大幅に短縮しています。

GPU使用率の観点から、GPUは最初からビジー状態であり、トレーニングの最初のエポックの終わりである1.87秒付近にいくつかのピークあることがわかりました(右下図)。これは、GPU(特にgpu_2)がパイプラインからのデータロードと変換同時にいくつかの処理を保持していることを物語っています。データパイプラインがリモートストレージデバイスからデータをロードするのと並行して、学習処理が開始されていることがわかります。GPUデバイスは約40%のピークにあり、すべてのGPUを100%利用するにはMLPのレイヤーがかなり小さいので妥当なところです。

私たちが発見した興味深い事実は、クラウドストレージの場所は私たちのアルゴリズムをホストしているマシンと異なる場合、リモートデータ取得により最初のエポックに余分な時間が追加されるということです。これは通常、クラウドインフラの仕様で、ユーザーがアクセスするエッジロケーションにデータをダウンロードする必要があるためです。ホストマシンから初めてアクセスした後、データはエッジロケーションに保存され、トレーニングの実行時間が大幅に改善されます。

以下の図はトレーニングマシンとクラウドストレージが異なる地域または場所にある場合の結果です。

トレーニングマシンとクラウドストレージが異なる地域または場所にある場合

最初のアクセス後、エッジロケーションのキャッシュにより、ランタイムが改善されました。

最初のアクセス後、エッジロケーションのキャッシュにより、ランタイムが改善されました

計算量が大きいほどGPUのデバイス使用率が高くなることを検証するため、より大きなMLPレイヤーと大きな画像でテストを行いました。ハードウェアの仕様は、前回のMNIST手書きデータセットでの実験と同じにしています。

Custom Dataset Environment Setting

以下の図は各エポックにおける異なるアクセラレータのアルゴリズム実行時間の比較です。

各エポックにおける異なるアクセラレータの平均実行時間

このシナリオでは、シングルGPUでの学習が最高のランタイムパフォーマンスをもたらすことが観察され、興味深いです。シングルGPUでは、CPUよりも約12倍高速になります。この特定のデータセットとマルチGPUデバイスの場合、マルチデバイスで実行する際のMap-Reduce操作のオーバーヘッドが原因だと思われます。

4GPUで学習と各エポックの実行時間 4GPUで学習と各エポックの実行時間

予想通り、層数と中間素子数が増えれば増えるほど、計算負荷が大きくなります。最初のエポック(0-28秒)の間(右下図)、ホスト上で起きているデータの前処理とGPU上のトレーニングステップが同時に実行されていることが観察されます。もちろん、MLPのレイヤーに生のピクセルを入力しているため、モデルの学習にはあまり期待できません。より良い結果を得るためには、畳み込みニューラルネットワークを使用することが望ましいでしょう。

まとめ

結論として、並列処理とキャッシュを備えたTF Dataライブラリを使用することでGPUデバイスのポテンシャルを引き出し、より高速な学習が可能になることが確認できました。GoogleやAWSのような大手ベンダーのクラウドストレージにデータを保存しつつ、データ取得を高速に行うことを説明しました。TF Dataではクラウドストレージだけでなく、BigQueryやBigtableなどからもリモートでデータを読み込むことができます。詳しい使い方はドキュメントをご覧ください。

また、マルチデバイス処理に関するvmapやpmapなど、JAXの便利な機能のデモンストレーションをしました。JAXは、NumPyのAPIのほとんどがJAXでミラーリングされているため、NumPyに慣れている人であれば簡単に使用できます。さらに、Autogradは、Pythonのネイティブ関数とNumPyの関数の微分を自動化できます。pmapの使用に適合するようにプログラムを開発すれば、JAXがバックグラウンド処理を引き受けてくれるのでCPUやマルチGPUデバイスに関係なくどこでもこのプログラムを実行できます。

私見ですが、JAXは非常に柔軟な使い方ができ、PoCを素早く行うためのコーディングが容易です。しかし、機械学習アルゴリズムを複数のプラットフォームで運用することを目指すのであれば、JAXはまだ成熟していないと言えます。TensorflowやPyTorchのような、強力なエコシステムを持ち、広く採用されている他のフレームワークに目を向けたほうがよいと思います。

さらに、JAXによるスケーラブルなインフラを実証するため、シンプルなMLPアルゴリズムを採用しました。JAXの複雑で高度なモデルを使って、MLの問題を解決出来ます。コードを入れ替えるだけ、本記事で取り上げたことはほとんど同じです。私は、FlaxやHaikuのような深層学習用のJAXフレームワークを使用することをお勧めします。JAXの公式GitHubリポジトリのチェックを忘れないでください。JAXを使ったハンズオンを楽しんでください。

本記事をシンプルにするため、コードブロックでの説明を省略し、すべてJupyter Notebookにまとめました。ぜひご覧ください。

ZOZO NEXTでは、機械学習を適切に使用して課題を解決できるMLエンジニアを募集しています。今回は、JAXを使ってレガシーアルゴリズムを改善した方法を紹介しました。MLアルゴリズムをファッションビジネスへ応用することに興味がある方は、以下のリンクからぜひご応募ください。

hrmos.co

hrmos.co

hrmos.co

カテゴリー