Employee Blog
社員ブログ

AI画像生成における学習 (その3)

AIの学習における問題点の理解と対策

はじめに

前回、ちょうど学習が開始するところまで終了しました。

今回はその学習結果をみながら、AIの学習における問題点と学習結果の評価について見ていきたいと思います。

1. 勾配爆発(発散)問題

学習曲線 順調ー学習が進まないー発散

勾配爆発(発散)問題は、発生すると 作成したモデルがだめになってしまいます。

発散前のモデルを確認した結果、学習自体はまだする余地があるとした場合、これ以上学習がすすめられないことは問題になります。

前回で紹介した TensorBoard の出力を見れば、作成したモデルを確認しなくとも学習が発散しているのかどうかがわかるのでTensorBoardのグラフを確認しながら学習の状況を把握します。

機械学習では、目標とどれくらいの差があるのかを loss という形で評価します。どっちに向かって学習を進めていけば loss がより小さくなるかを傾きから得て、そちらに向かって1歩進んで、またその地点での傾きを見てどっちに進めば…と、繰り返して 最終的に loss値 が一番小さい地点を探します。

勾配降下法 lr = 0.1

例えば上のようなパラメータの曲線があったとして、loss が最小の 一番下の地点を目指すとします。

ランダムなスタート地点(今回は左上)からスタートして、loss 値と その地点の勾配(gradient) を計算して少なくなる方向に向かって( descent ) パラメータを修正して再度結果を見るということを繰り返していきます。 学習率 (learning rate, lr) が 0.1 の場合は20 step くらいでだいたい一番下の地点で収束しました。

learning rate を もっと大きくするとどうでしょう。例えば 0.2 にすると

lr = 0.2

lr = 0.2 にすると 10 step 程度で一番下の地点に収束しました。

では、もっと調子に乗って lr = 1 ならどうでしょう。

lr = 1

一番下に収束するどころか、左と右を行ったり来たりするばかりで一向に収束しません。

これが更に大きくなると…

lr = 1.1

収束するどころか、どんどん値が大きくなってしまいました。これが「発散」している状態です。

学習率( lr ) が小さいと、発散はしないが収束するまでに時間がかかる。 lr が大きすぎると 学習が進まなくなったり、発散したりするということがわかると思います。

2. スケジューラ

lr をいくつにすればよいかについては、元のモデルと学習する対象がどれだけパラメータが違っているかによるために、一概に「この値が best」といえる値がとれません。小さくしておけば発散することはないので安全ですが、学習に時間がかかってしまいます。

そこで、スケジューラの登場です。最初は比較的大きな値にしておき、学習が進むにつれてだんだん小さい値を lr に設定することで学習にかかる時間を短くしながら発散も防ぎます。

画像生成AIの学習で使われるスケジューラは以下のようなものがあります。

  • constant 常に一定の値を使います。あまり使われません。
  • linear 開始時 lr = lr で 最終 epoch で lr = 0 になるように落ちる。
  • cosine 序盤ゆっくり減少、中盤で傾きがピーク、後半またゆっくり減少しまう。よく使われます。
  • cosine with restart cosine と同じように減少、0になったらまた スタート lr に戻る動作を複数回繰り返す どのくらいの step / epoch が最適かわからない場合が発散を防ぐためによく使われます。
スケジューラ関数
cosine with restarts

初めてでよくわからない学習の場合には cosine with restarts を指定して step/epoch を大きめにとり、途中経過を見ながら学習の進み度合いを見たりします。

3. 再度学習開始!

今回は、cosine_with_restarts を使って見ようと思います。
第1回のときに紹介した RedRayz / Kohya_lora_param_gui (https://github.com/RedRayz/Kohya_lora_param_gui) を起動して、詳細設定から

kohya_lora_param_gui 設定画面

ページ1 タブの スケジューラの項目を cosine_with_restarts に変更して、設定を反映して閉じるを押して閉じます。

kohya_lora_param_gui 詳細設定画面

出力ファイル名をわかりやすい名前に変更して学習開始!

出力ファイル名を変更して学習開始

学習の完了を、ご飯でもたべながら気長に待ちます。長い時間かかるので、寝る前に仕掛けて朝起きて確認する人もいます。

学習が完了したら TensorBoard のグラフを確認します。

lr = 0.001 constant(赤) と consine with restarts(青) の比較

前回のグラフが赤。700step / 45 epoch あたりで発散していましたが、今回は 800step / 50 epoch 完走しています。「発散を防ぐ」という意味では成功です! 

学習の結果は単純に loss 値の大小では一概に判断できませんが、このグラフだけから判断すると前回の学習より安全方向に倒した為に学習は進んでいないように見えます。

実際に画像を出力して確認します。

プロンプトは、「sail_kun, beach, masterpiece」 の beach の箇所を city, forest, desert に変えて、プロンプトの反応性を見ています。

Version1 と Version2 の比較評価

上段が前回作成した v01 の 34 epoch のもの、下段が今回作成した v02 です。

上段の v01 はなんとか セイルくんが出ています。前回は全く出ていなかった背景も city と forest で出そうと頑張っている雰囲気はあります。desert は色だけ?

v02 はそれに比べると明らかに学習不足です。目が2つ書かれたり、口が2つ書かれたり、形が崩れたりと散々です。v01 と比べ、学習率を減らす方向に制御したために、発散はしなかったものの学習不足になったようです。もう少し追加で学習する必要がありそうです。

4. まとめ

今回は課題1)勾配爆発 の対応について検討しました。学習率 ( learning rate, lr ) が大きすぎて発散してしまう問題を、今回はスケジューラの設定で回避しました。

次回は別の課題の解決に取り組んでいきたいと思います。

それでは今回はこれまで。


参照元

以下のサイトを参考しました。

【ディープラーニング入門(第4回)】勾配降下法を学んでディープラーニングの学習について理解しよう
https://qiita.com/kwi0303/items/7bfd7180f80a52296e64


利用したツールです。感謝!

kohya-ss / sd-scripts

https://github.com/kohya-ss/sd-scripts

RedRayz / Kohya_lora_param_gui

https://github.com/RedRayz/Kohya_lora_param_gui