誰でもわかるStable Diffusion その7:U-Net(テキストを画像に反映するAttentionブロック)
今回はStable Diffusionがどうやってテキストを画像に反映させているのかを見ていきます。
重要な役割を担うのは「Attention」です。これには「Transformer」というメカニズムが使われていますが、これは少し複雑なので今回は詳細は省いて、大まかに何をしているかだけを説明します。
Attentionブロック:テキストを画像に反映
Stable Diffusionの心臓部、U-Netの構造を再び見てみましょう。
ところどころにオレンジの部分があります。これが「テキストを画像に取り込む」ブロック、「Attentionブロック」です。
Attentionブロックこそが人間とAIのコミュニケーションを担う場所です。
このAttentionブロックがないと、Stable Diffusionにどんな絵を描きたいかを伝えることができません。
指示がないと、Stable Diffusionは知っている絵をランダムに適当に描きます。それでは使い物になりませんね。
ここでいう「描く」とは、「ノイズ画像からノイズを取り除いて知っている絵に近づける」作業のことです。これはResブロックが行います。
Attentionブロックはすべてのブロックにあるわけではありません。
テキスト情報を取り込むブロックは、
IN1、IN2、IN4、IN5、IN7、IN8
MID
OUT3、OUT4、OUT5、OUT6、OUT7、OUT8、OUT9、OUT10、OUT11
の計16ブロックです。
特に後半(OUTブロック)に何度もテキストを反映させて、画像をテキスト通りに描かせようとしている様子がうかがえます。
Attentionブロックの概要
まずはAttentionブロックが何をしているのかを大まかに把握します。
入力準備だの出力準備だの後処理だのは単なる「データ変換」で、これは本質ではありません。
このブロックで重要なのは「比較」処理です。「Attention」と呼ばれます。
Attentionとは
比較処理は2回行われます。
- 画像内のそれぞれのマスが他のマスとどう関係しているか
- 画像内のそれぞれのマスが与えられたテキスト(トークン)とどう関係しているか
前者を「Self Attention」、後者を「Cross Attention」と呼びます。
Attentionとは「注意を向ける」という意味です。「Self」は「自分自身」、「Cross」は「自分以外」との比較を表します。
もちろんやっているのは比較だけではありません。比較した結果、より関係が深いと思われるデータを自分に取り込んでいきます。
例えば、あるマスが「black」というテキスト(トークン)に深く関係していると判断された場合、そのマスのデータに「black」由来のデータを多く取り込みます。
こうして元のデータが「関係の深い情報を取り込んだデータ」に変換されます*1。
茶色の矢印は、「データを足す」ことを表しています。つまり、データに何か処理を行った後は、必ず処理を行う前のデータを足します。これはAttentionによって変換されてしまった前の情報も忘れないためです。「ちゃんと賢くなる」ための工夫でもあります。
Attentionブロックの構造
Attentionブロックが何をしているのか大まかに分かったので、中身を少し詳しく見ていきます。
Self AttentionとCross Attentionについては上で概要を説明したのでここでは省きます。
Normalization:正規化
まずあちこちで見られる「Norm」について。
Normはデータの正規化を行う処理です。数字を足したり掛けたりするとどんどん数が大きくなって制御しにくくなるので、正規化によって常に制御範囲内に収めるようにしています。
ここではGroup NormとLayer Normの2種類が使われています。
Group Norm:グループごとに正規化
ブロック最初に現れる「Group Norm」は、Resブロックでも出てきました。
これは320チャンネル*2を10個のグループに分けて、そのグループ内でそれぞれ正規化する処理です。
Layer Norm:チャンネル内で正規化
Layer Normは文字通り「レイヤー(チャンネル)ごとの正規化」、つまりそれぞれのチャンネル内で正規化を行います。
Linear:データを賢く変換
次に「Linear」というブロックですが、これはただのニューラルネットワークです。
その役割は「データを変換する」ことですが、この「変換」は学習によって賢くなります。
それぞれのブロックの最後にLinearブロックを入れることで、賢く変換されたデータが出力されるようになります。
Conv:1x1の畳み込み
最初と最後のブロックにある「Conv」ブロックは、畳み込みを行っています。
畳み込みについての詳細は過去の記事を参考にしてください。
Resブロックの畳み込みと違って、ここのフィルターの大きさは1x1マスです(Resブロックのフィルターは3x3)。
1x1マスの畳み込みは普通、チャンネルの数を変換する(減らす)のが目的ですが、実際はAttentionブロックで使われるConvでは処理前後でチャンネル数は変わりません。
つまり、320チャンネルのデータを入れると320チャンネルそのままのデータが返ってきます。ただしサイズは変わりませんがデータの中身は変わっています。
本質的にはConvもLinearの一種です。
GeGLu:非線形要素を追加
最後にGeGLuですが、これは非線形変換です。「非線形」というのはLinear変換に加えるスパイスのようなもので、これがあると変換が複雑になる、というくらいの理解で構いません。
画像のような複雑なデータを扱うので、複雑な変換ができる機能を追加しています。
GeGLuとはGELUの一種です。
GELUは「マイナスの値を受け取ったらはほぼ0を返し、プラスの値はほぼそのままを返す」関数です。
GeGLuでは、まず入力データが2つに複製されます。
- 片方はGELUに通しますが、その前にLinear処理を行います。つまりニューラルネットに通します。
- もう片方にも別のLinear処理を行いますが、こちらはGELUを通りません。
最後に、GELUに通したデータと通さなかったデータを重ね、重なった所を掛け合わせます。
複雑に見えますが、やっていることはただの非線形化です。
まとめ
Attentionブロックは、ユーザーが入力したテキストを画像に反映させる仕組みです。これがないとStable Diffusionは好き勝手に絵を描いてしまいます。
AttentionブロックはU-Netに複数個あり、画像生成途中でテキストが何度も反映されていることが分かります。
重要なのは画像のマス同士の関係を比較する「Self Attention」、画像のマスと与えられたテキスト(トークン)を比較する「Cross Attention」です。これらが画像に情報を取り込み、テキストを反映させていきます。
Attentionを行うメカニズムである「Transformer」についてはまた別記事で。