誰でもわかるStable Diffusion その6:U-Net(IN1、Resブロック)
Stable DiffusionのU-Net解説の3回目です。
今回はIN1ブロックの中にある「Resブロック」を詳しく見ます。
- IN1ブロックの構造
- Resブロックの構造
- Group Norm:グループごとに正規化
- SiLU:非線形な要素を加える
- Conv:絵の特徴を抽出
- Time Embed:ノイズの量を知らせるスケジュール係
- 処理後に元データを足す:「Res」という名前の由来
- まとめ
IN1ブロックの構造
上のU-Netの図の左上にあるIN1ブロックは「Resブロック」(青色)と「Attentionブロック」(オレンジ色)の2つのブロックからできています。
Resブロックは、入力された画像が持っている特徴を探し出す処理を行います。また、画像に混ざっているノイズ量の情報を画像に追加します。
(Attentionブロックは次回見ます。)
Resブロックの構造
Resブロックが行っている処理は以下の図の通りです。
「Group Norm」→「SiLU」→「Conv」という処理を2回やっていることに気付かれるでしょう。この2回の処理の間に「Time Embed」という処理を行っています。
ちなみに「Dropout」という処理が2回目のConv前に置かれていますが、実際のStable Diffusionではこの処理はスキップされている(通る前と後でデータの中身が変わらない)ので無視してもかまいません。
Group Norm:グループごとに正規化
まずデータに対して「Group Normalization」という処理を行います。
(データというのは、前のブロックから出てきた画像データです。例えばIN1のResブロックの場合、IN0の結果をデータとして受け取ります)
入力データは320のチャンネル(特徴)を持っています。まず、この特徴チャンネルを32チャンネルずつ10個のグループに分けます。
このグループごとに数字を正規化します。グループ内の平均と分散の値にもとづいて正規化します。たとえて言うならグループ内の偏差値を計算するようなものです。
あるマスの値がグループ内の平均値と全く同じなら、正規化後の値は0になります*1。
この処理の意味は、おもに計算の効率化のためです。
ちなみに、この処理も「学習によって賢くなる」パーツです。
SiLU:非線形な要素を加える
次に、「活性化関数」と呼ばれるものにデータを通します。Stable Diffusionでは「SiLU」という名前の関数を使います。
この関数に通すと、マイナスの数値は「ほとんど0」になります。プラスの数値は「だいたい同じ」数値が出てきます。やっているのはこれだけです。
なぜこんなことをするのかというと、Resブロックが「非線形関数」であってほしいからです。
「線形関数」だと複雑な処理ができない、というくらいの理解でかまいません。
何億枚という絵を記憶するには単純な関数だと役不足なのです。
Conv:絵の特徴を抽出
最後に「畳み込み」処理(Convolution)を行って、画像が持つ特徴を抽出します。これは前回の記事で説明しました。
このブロックでやっているConv処理を図で表すと下のような感じです。
IN0で行ったConv処理は特徴チャンネルが4から320に増えました。
しかしここでは、入力データも出力データも特徴チャンネルは320個で、Conv前後でデータの形は全く変わりません。
では、何もやっていないのかというと、もちろんそんなことはありません。
上の図のフィルターを見てください。マス数は3x3ですが、チャンネル数が320になっています。この3x3x320というフィルターを使って畳み込みをすると、出力として新たに1チャンネル分のデータが出てきます。
こうしたフィルターが320個あるので、新たな320チャンネルのデータが作られるわけです。
このフィルターは何をしているのでしょう?
あるフィルターをデータに重ねてみましょう。フィルターは重ねたエリアにある320個の特徴をそれぞれ見ます。例えばそのエリアで「タテ棒」チャンネルと「ヨコ棒」チャンネルが大きな数値を持っているなら、きっとそのエリアはより複雑な「十字」という特徴を持っているでしょう*2。そうやって単純な特徴からより複雑な特徴を見つけ出すことができます。
一般的に、Conv処理を重ねれば重ねるほど、より複雑な特徴を検出するようになります。
Time Embed:ノイズの量を知らせるスケジュール係
畳み込みが終わったら、今の状態の画像にどれだけノイズが混ざっているか、という情報を画像に追加します。
画像データは完全なノイズ画像から始まって、これを何度もU-Netに通すことで徐々にノイズが取り除かれて絵になっていきますが、U-Netというのは1つしか存在しません。ノイズがたくさん乗っているときも、ほとんど取り除かれているときも、同じU-Netを使います。
そのため、「画像にどれくらいノイズが乗っているか」という情報がないと、U-Netがどれだけノイズを取り除いていいのか分からなくなってしまいます。
そこで、Time Embed処理で「今、絵がどれくらいの段階にいるか」という「時間情報」を絵のデータに追加します。
この情報を画像に持たせることで、「どれくらいの量のノイズを取り除いたらいいか」という情報をこの後のブロックに伝えることができます。
時間情報は1280個の数字からなるベクトルです。このベクトルをどうやって作るかの説明は省きますが、このベクトルによって今の画像がどの段階にいるのかを知ることができます。
さて、このベクトルを画像に足したいのですが、画像は320チャンネル(つまり数字が320個)なので1280個と合いません。そこで、このベクトルをSiLU+Linearという処理に通して320個に圧縮します。
Linearは学習によって賢くなるパーツです。ここに通すと1280個の数字が320個の数字になって出てきます。
チャンネル数が同じになったので、全部のマス(例えば64x64マス)に対してこのベクトルを足します。
これでデータに「時間情報」が追加されました。
この後、もう一度「Group Norm」「SiLU」「Conv」の処理を行います。
処理後に元データを足す:「Res」という名前の由来
上の「Resブロックの構造」の図をもう一度見てください。
全部の処理をすっ飛ばして、右端に直接つながっている細い矢印があります。これは、「処理し終わったデータに、処理する前のデータを足す」ことを表しています。これこそが「Resブロック」と言われるゆえんです。
ResとはResidual(残りもの)という意味です。
ある処理ブロックがあって、そこにXを入れるとX+dが出てくるとします。つまりその処理ブロックはdの部分を求めていることになります。このdこそがXを取った後の「残り物」です。Resブロックはこれをやっているのです。
これを行う理由はいろいろありますが、一番の理由は「ちゃんと賢くなるため」です。
単純な事ですが、この「足す」という処理はニューラルネット界隈では大発見でした。
まとめ
U-NetのIN1ブロックはさらに細かい2つのブロックからなりますが、そのうちの一つ、「Resブロック」について詳細に見ました。
ここでは正規化、非線形化、畳み込みを2回行い、その間で時間情報を埋め込みます。
これらの処理が終わった後、処理前のデータを足して、次のブロックへと渡します。
次の記事は、IN1の2つ目のコンポーネント「Attentionブロック」について。