masato-ka's diary

日々思ったこととか、やったことの備忘録。

JetBotのraod_followingサンプル学習時のGPUメモリ解放忘れについて

この記事について

この記事はAI RC Carアドベントカレンダー4日めの記事です。4日目の今日は小ネタ中の小ネタです。JetBotのサンプルに含まれるroad_followingの学習ノートブックの修正についてです。

road_followingは学習が遅い

JetBotのサンプルにはroad_followingと呼ばれる、コースを追従するためのサンプルがあります。モデルを学習するためのtrain_model.ipynbノートブックで学習を行おうとすると学習に時間がかかり、また500枚程度の画像で、GPUメモリが溢れてしまい学習ができなくなってしまいます。

loss値の解放忘れ

この現象の原因は学習時に各エポックごとのloss値の総和を求める際にlossの計算結果をGPUメモリに乗せたまま計算していることが原因と考えられます。 train_model.ipynbの最終ブロックをみてみると以下のようなコードをみることができます。

for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.mse_loss(outputs, labels)
        train_loss += loss
        loss.backward()
        optimizer.step()
    train_loss /= len(train_loader)

この時、mse_lossの計算結果はGPUメモリ上に乗ったままになっています。ですので、上記コードの7行目をtrain_loss += float(loss)とします。これはfloatでキャストしているだけですが、Pytorchのloss.detach().cpu()と同じ効果を得られるようです。学習時とバリデーション時両方忘れずに修正しましょう。 この修正を夏頃プルリクエストで投げてみましたが、あまりにもニッチすぎるのか、それとも外部からのプルリクエストは受け付けない方針なのか、マージされません。ですので個別に修正することをお勧めします。

github.com

## まとめ

とても微妙な箇所ですが、実際にroad_followingの学習を初めてみると他のサンプルの学習に比べてすこい遅い感じがするので、road_followingを試す場合はこの修正をお勧めします。