ディープラーニングモデルの学習でGPUメモリが足りず、モデルサイズや画像サイズ、トークン長、バッチサイズを大きくできないなど困ったことはありませんか?
今回はディープラーニングモデル学習時のGPUメモリを大幅に節約できる、「Gradient Checkpointing」という手法を紹介します。
Gradient Checkpointingとは
Gradient Checkpointingは勾配計算を工夫することでメモリ消費量を大幅に抑えられる手法です。
そのトレードオフとして、学習時間が増加します(メモリを抑えられる分バッチサイズを増やすことができるので、実際にはその差し引きで高速化することもあります)。
近年のTransformer系のモデルはモデル自体のサイズが大きく、少ないGPUでこれらのモデルを学習させるにはGradient Checkpointingは必須と言ってよいほど重要な手法となっています。
使用方法
Gradient Checkpointingをスクラッチで実装するにはモデルの勾配計算に関する深い知識が必要となり、非常に難易度が高いです。
今回は簡単にGradient Checkpointingを使用する方法として下記の2通りを紹介します。
使用方法①:timmライブラリのモデルを使う
timm : GitHub - huggingface/pytorch-image-models
timmは数々の事前学習済みモデルを手軽に利用できる便利な画像認識ライブラリです。
(Kaggleの画像認識コンペでは必ずといってよいほど使われており、画像認識タスクにとって必須級のライブラリです)
timmを使えば事前学習済みモデルのインスタンス化とGradient Checkpointingの有効化を以下のように簡単に行うことができます。
import timm
model = timm.create_model(model_name="efficientnet_b0", pretrained=True, num_classes=10) # モデルのインスタンス化
model.set_grad_checkpointing() # Gradient Checkpointingを有効化
※モデルによってはGradient Checkpointingが未実装の場合もあります
使用方法②:Hugging Faceライブラリのモデルを使う
Hugging Face : Efficient Training on a Single GPU (huggingface.co)
Hugging FaceのモデルにはGradient Checkpointingが実装されているモデルが多くあります(一部のモデルでは非対応で使えないケースがあります)。
Gradient Checkpointingを有効化する方法はこちらも簡単で、モデルインスタンスに対して以下のようにgradient_checkpointing_enable()を設定するだけです。
from transformers import AutoConfig, AutoModel
from torch.utils.checkpoint import checkpoint
# initializing model
model_path = "microsoft/deberta-v3-base"
config = AutoConfig.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path, config=config)
# gradient checkpointing
model.gradient_checkpointing_enable()
print(f"Gradient Checkpointing: {model.is_gradient_checkpointing}")
最後に
GPUのメモリ節約方法はいくつかありますが、その中でもGradient Checkpointingは実装済みのライブラリを使用すればお手軽に使用できるのでオススメです。
GPUメモリが足りず学習で困っている場合は是非試してみて下さい。