誰でもわかるStable Diffusion その10:U-Netの各ブロックの働き
前の記事まで、Stable Diffustionで使われているU-Netがどういう仕組みで画像データを処理しているのかをずっと見てきました。
一通り説明が終わったので、U-Netの全体像と、各ブロックが何をしているのかを少し詳しく眺めていきます。
U-Netについては以前少し紹介したのでその記事も参考にしてください。
注意
ここで説明するU-Netの構造はStable Diffusion独自のものです。
U-Netが必ずこういう形をしているわけではありません。
U-Netは3ブロック、5段階構造
Stable DiffusionのU-Netの構造は下の図の通りです。
上下で見ると、画像データのサイズの違いから4段構えになっています(下の2段は画像サイズが同じなので、これらを1段とカウントしています)。
左右で見ると、以下の3ブロックに分かれます。
- 画像を小さくしながら特徴を取り出していく「INブロック」(エンコーダーと呼ばれることもあります)
- 最小画像、最大特徴数を処理する「MIDブロック」(ボトルネックと呼ばれることもあります)
- 画像を大きくしながら特徴をまとめていく「OUTブロック」(デコーダーと呼ばれることもあります)
IN、MID、OUTもそれぞれ
- IN0からIN11までの12ブロック
- MID
- OUT0からOUT11までの12ブロック
の合計25ブロックに分かれます。
さて、上下の各段を見ると、INでは基本的に1段が2つのINブロックで構成されていることが分かります。OUTでは1段が3つのOUTブロックで構成されます。
つまり、INでは2つのブロックが同じサイズの画像データを処理します。OUTでは3つのブロックが処理をします。
IN、MID、OUTの各ブロックを見てみると、さらに処理を細かく分けることができます。
例えばIN1を見てみると、「Resブロック」と「Attention」ブロックの2つからできています。
MIDブロックは「ResーAttentionーRes」というサンドイッチ構造をしています。
下段のIN10、IN11、OUT0、OUT1、OUT2では「Resブロック」しかありません。
このように、ブロックによって処理も違います。
各ブロックの処理と特徴
さて、ここからは各ブロックの主な特徴を見ていきます。
注意
- 分かりやすくするために、以下の説明で処理される画像データのサイズは「タテ64xヨコ64」とします。
- Attentionブロックで画像に反映されるトークン(テキスト単語)の影響の強さは「弱、中、強」で表します。
- トークンが画像中のどれくらいの範囲に反映されるかは「全体的、やや微細、微細構造、超微細構造」で表します(ただしこの範囲は画像やトークン内容によって変わる可能性があります)。
INブロックの処理
IN0
処理:畳み込みのみ
入力サイズ:タテ64xヨコ64x特徴4
出力サイズ:タテ64xヨコ64x特徴320
IN1
処理:Resブロック、Attentionブロック
入力サイズ:タテ64xヨコ64x特徴320
出力サイズ:タテ64xヨコ64x特徴320
トークンの影響:弱、全体的
IN2
処理:Resブロック、Attentionブロック
入力サイズ:タテ64xヨコ64x特徴320
出力サイズ:タテ64xヨコ64x特徴320
トークンの影響:中、全体的
IN3
処理:畳み込みのみ(サイズ縮小)
入力サイズ:タテ64xヨコ64x特徴320
出力サイズ:タテ32xヨコ32x特徴640
IN4
処理:Resブロック、Attentionブロック
入力サイズ:タテ32xヨコ32x特徴640
出力サイズ:タテ32xヨコ32x特徴640
トークンの影響:強、全体的
IN5
処理:Resブロック、Attentionブロック
入力サイズ:タテ32xヨコ32x特徴640
出力サイズ:タテ32xヨコ32x特徴640
トークンの影響:弱、微細構造
IN6
処理:畳み込みのみ(サイズ縮小)
入力サイズ:タテ32xヨコ32x特徴640
出力サイズ:タテ16xヨコ16x特徴1280
IN7
処理:Resブロック、Attentionブロック
入力サイズ:タテ16xヨコ16x特徴1280
出力サイズ:タテ16xヨコ16x特徴1280
トークンの影響:中、全体的~やや微細
IN8
処理:Resブロック、Attentionブロック
入力サイズ:タテ16xヨコ16x特徴1280
出力サイズ:タテ16xヨコ16x特徴1280
トークンの影響:中、全体的
IN9
処理:畳み込みのみ(サイズ縮小)
入力サイズ:タテ16xヨコ16x特徴1280
出力サイズ:タテ8xヨコ8x特徴1280
IN10
処理:Resブロック
入力サイズ:タテ8xヨコ8x特徴1280
出力サイズ:タテ8xヨコ8x特徴1280
IN11
処理:Resブロック
入力サイズ:タテ8xヨコ8x特徴1280
出力サイズ:タテ8xヨコ8x特徴2560
MIDブロックの処理
MID
処理:Resブロック、Attentionブロック、Resブロック
入力サイズ:タテ8xヨコ8x特徴2560
出力サイズ:タテ8xヨコ8x特徴1280
トークンの影響:強、全体的~やや微細
OUTブロックの処理
OUT0
処理:Resブロック
入力サイズ:タテ8xヨコ8x特徴1280+IN11出力(タテ8xヨコ8x特徴1280)
出力サイズ:タテ8xヨコ8x特徴1280
OUT1
処理:Resブロック
入力サイズ:タテ8xヨコ8x特徴1280+IN10出力(タテ8xヨコ8x特徴1280)
出力サイズ:タテ8xヨコ8x特徴1280
OUT2
処理:Resブロック
入力サイズ:タテ8xヨコ8x特徴1280+IN9出力(タテ8xヨコ8x特徴1280)
出力サイズ:タテ16xヨコ16x特徴1280
OUT3
処理:Resブロック、Attentionブロック
入力サイズ:タテ16xヨコ16x特徴1280+IN8出力(タテ16xヨコ16x特徴1280)
出力サイズ:タテ16xヨコ16x特徴1280
トークンの影響:弱、微細構造
OUT4
処理:Resブロック、Attentionブロック
入力サイズ:タテ16xヨコ16x特徴1280+IN7出力(タテ16xヨコ16x特徴1280)
出力サイズ:タテ16xヨコ16x特徴640
トークンの影響:強、微細構造
OUT5
処理:Resブロック、Attentionブロック
入力サイズ:タテ16xヨコ16x特徴640+IN6出力(タテ16xヨコ16x特徴1280)
出力サイズ:タテ32xヨコ32x特徴1280
トークンの影響:中~強、やや微細
OUT6
処理:Resブロック、Attentionブロック
入力サイズ:タテ32xヨコ32x特徴1280+IN5出力(タテ32xヨコ32x特徴640)
出力サイズ:タテ32xヨコ32x特徴640
トークンの影響:強、超微細構造
OUT7
処理:Resブロック、Attentionブロック
入力サイズ:タテ32xヨコ32x特徴640+IN4出力(タテ32xヨコ32x特徴640)
出力サイズ:タテ32xヨコ32x特徴320
トークンの影響:強、微細構造(対象トークンに関連する部分?)
OUT8
処理:Resブロック、Attentionブロック
入力サイズ:タテ32xヨコ32x特徴320+IN3出力(タテ32xヨコ32x特徴640)
出力サイズ:タテ64xヨコ64x特徴640
トークンの影響:強、全体的~微細構造(対象トークンに関連する部分?)
OUT9
処理:Resブロック、Attentionブロック
入力サイズ:タテ64xヨコ64x特徴640+IN2出力(タテ64xヨコ64x特徴320)
出力サイズ:タテ64xヨコ64x特徴320
トークンの影響:強、全体的
OUT10
処理:Resブロック、Attentionブロック
入力サイズ:タテ64xヨコ64x特徴320+IN1出力(タテ64xヨコ64x特徴320)
出力サイズ:タテ64xヨコ64x特徴320
トークンの影響:弱、全体的
OUT11
処理:Resブロック、Attentionブロック
入力サイズ:タテ64xヨコ64x特徴320+IN0出力(タテ64xヨコ64x特徴320)
出力サイズ:タテ64xヨコ64x特徴4
トークンの影響:中、全体的~やや微細
U-Netが作り出すもの
前の記事でも解説したとおり、U-Netは「ノイズだらけの絵」を取り込んで、その画像をもとに「絵に乗っているノイズ」を予想して、ノイズだけを出力します。
つまり、OUT11ブロックが最終的に出力するものは「予想ノイズ画像」です。
U-Netが出力した「予想ノイズ」を元の「ノイズだらけの絵」から引けば、「ノイズのない絵」ができる、という仕組みです。
もしU-Netが予想して作った「ノイズ画像」が完全に正しいなら、ノイズだらけの絵からこのノイズ画像を引けば一発で完全に正しい絵ができるはずです。
しかし、予想が完全に正しいということはあり得ません。そこで、この「予想ノイズ画像」を弱めて(薄めて)、元の絵から少しだけノイズを引くようにします。
元の絵にはまだノイズが残っているので、またU-Netに突っ込みます。U-Netはまた乗っているノイズを予想し、出力します。
こうしてU-Netの処理を繰り返して絵を完成させていきます。
この「ノイズを予想する」処理を「サンプリング」と言い、予想を繰り返す回数を「サンプリングステップ」と言います。
上で「ノイズだらけの絵から予想ノイズを引く」と書きましたが、「画像生成のどの段階でどれだけノイズを引くか」を決めるのが「サンプラー」(または「スケジューラー」)と呼ばれるモジュールです。
あるサンプラーでは生成の最初から最後まで毎回ほぼ一定量のノイズを引き続けます。
他のサンプラーでは生成初期の段階ではノイズをごっそりと引き、絵からノイズが少なくなったら今度は少しずつノイズ引いていきます。
生成方法が違うので、サンプラーによってできる絵が微妙に違います。
(サンプラーについてはまた後日機会があれば解説します)
まとめ
U-Netの各ブロックが何をしているのかを全体的に見ました。
IN、MID、OUTの3パートに分かれていて、INでは画像を縮小しながら特徴抽出、OUTでは画像を拡大しながら特徴をもとに新たな画像を作り出していく働きがあります。
各ブロックは一見似たような処理を行っているように見えて、役割はそれぞれ微妙に違うようです。
画像データはこのU-Netを通り抜けて、「予想ノイズ画像」に変換されます。
これを使って元の画像データから少しずつノイズを取り除き、絵を完成させるのです。
U-Netの解説はこれでおしまいです。