verbalize

文章書く練習…。

3DGANをchainerで実装した

タイトルの通り、3DGANのchainer実装をgithubに上げた。当初はkerasで書いていたが良い結果が得られず、ソースコードの間違い探しをするモチベーションが下がってきたので、思い切ってchainerで書き直した。

github.com

実はmnistなどのサンプルレベルのものを超えてちゃんとディープラーニングでタスクに取り組むのは今回が初めてだった。Chainerによるgan実装自体は公式のexamplechainer-gan-libが非常に参考になった。

モデル

3DGANはその名の通り3Dモデルを生成するためのGAN。[1610.07584] Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modelingで提案されているもの。前回の記事でも触れた。

構造はDCGANと同様で200次元のベクトルよりGeneratorでサンプルを生成、Discriminatorでデータセット由来かGenerator由来か(real/fake)を分類しそのロスをフィードバックする。

Generatorは以下の図(論文より引用)のようなネットワークで、Discriminatorはこれを反転したようなモデルになっている。最適化手法はAdamで、論文ではDiscriminatorがバッチを8割以上正しく分類できた場合はパラメータを更新しないようにしたとあった。

f:id:bonhito:20171024233301p:plain

3Dモデル

データセットにはShapeNet-v2を用いた。このデータセットには様々な種類の3Dモデルが収録されているが、今回は椅子のモデルのみを抽出した。椅子はおよそ6700サンプルが収録されており、ファイル形式は.binvoxが直接収録されていたのでそれを使用した。

ただ、6700サンプルの3Dデータを全てメモリに乗せることはできなかったため、初期実装では毎回のループで読み込み処理を行っていた。その後、.binvoxファイルのヘッダー読み込みなどが不要であり、処理速度に支障があると感じたので事前に.h5に書き出して使うようにした。

ShapeNet-v2に収録されているデータのサンプルを示す。

f:id:bonhito:20171024234729p:plain:w350f:id:bonhito:20171024234743p:plain:w350f:id:bonhito:20171024234801p:plain:w350f:id:bonhito:20171024234823p:plain:w350

実装

3DGANの実装をやろうと決めてから、すぐにはディープラーニングには手を付けず、3Dモデルの取扱について理解するためのツールをつくっていた。主にはSimple Voxel Viewerで、.binvox形式について理解したり、matplotlibでボクセルをどうやってプロットしようかということについて考えていた。

64x64x64のボクセルを可視化するため、最初はmatplotlibの3Dplotを試したが、scatter plotsurface plotを使うとマインクラフトのような箱を集積した見映えのプロットが実現できない上、一つ描画するのに数十秒かかることがわかった。そこからまず自作してみようと思いTHREE.jsを使ってSimple Voxel Viewerを作ってみた。ところが結局こっちもいくらか高速化は試したものの、64x64x64のサイズでも密なボクセルになるとメモリーエラーが起こってしまいうまく動作しない問題が起こった。加えて当たり前だがPythonのコードにも組み込めない。

そうして結局、matplotlibの3D volxe plotを採用した。しかしこの関数はまだリリースされていないため、githubから直接インストールする必要があった。動作も遅いままだが妥協することにした。

ネットワークはKerasやTensorflowなどによる実装がいくつかgithubに上がっていたためそれらも参考にしつつ実装した。加えて、有名なGANのベストプラクティスのページを参考にした。ポイントをかいつまむと以下のような感じで実装した。

  • ランダムベクトルzは論文では一様分布だったがガウス分布を使った。
  • GeneratorはDeconv3D+BN+ReLUの繰り返しで、最後だけsigmoid
  • DiscriminatorはConv3D+BN+Leaky-ReLUの繰り返しで、最後だけsigmoid
  • Chainerの公式のexampleを真似してロスはsoftplusを使って実装。ただ、実はsigmoid + Adversarial losssoftplusと同じなのでDiscriminatorの最後のsigmidは不要なのだが、加えた方がうまくいった(謎)。

結果

成功例

良さげな感じを出すためにきれいなものを集めた。学習の初期段階ではでたらめなものが出力されるが、徐々に椅子が形成され、50エポックから100エポックくらいでましなものが出来た。

学習の途中では椅子とは独立した無意味なかたまりのオブジェクトが所々に浮かんでいたりしたが、それが消えてくるとかなり見栄えが良くなっていった。

https://github.com/piyo56/3dgan-chainer/blob/master/result/generated_samples/png/7.png?raw=true https://github.com/piyo56/3dgan-chainer/blob/master/result/generated_samples/png/13.png?raw=true https://github.com/piyo56/3dgan-chainer/blob/master/result/generated_samples/png/21.png?raw=true https://github.com/piyo56/3dgan-chainer/blob/master/result/generated_samples/png/30.png?raw=true

失敗例

ボクセルが全て1になったり0になって消滅したりした。今回幾度も学習をさせてみて、初期の段階からほぼ1なボクセル、あるいはほぼ0なボクセルが生成されたり、規則的なパターン(模様)を持つボクセルが生成されたりすると多くの場合失敗となるという微妙な知見を得た。

また、ボクセルが消滅したらその後復活しないこともわかった。ただこれは実装のところで述べたようにロスが間違っているせいかもしれない。

https://i.gyazo.com/1c516e331073f2f35c20948f7e5358b0.png https://i.gyazo.com/c12285bb038635a17df18410202bc315.png https://i.gyazo.com/0b87e9115d56faa93332afa61d093ad8.png https://gyazo.com/1e67d6f8872df8b0b6e41d01021f69bf.png

わかったこと

  • GANはロスは全くあてにならない。生成結果が全て。
  • zガウス分布から取ってきたほうが良さそう。
  • 学習を調整する(Discriminatorのlossやaccを見て更新しないなど)のはうまくいかないと感じた。
  • 今回のコードではsigmoid + adversarial losssoftplusで実装しているので、Discriminatorの最後のsigmoidは不要なはずなのだが、誤って入れていたらうまくいき、外したらうまくいかなくなった。動きゃ勝ちみたいなところがあって釈然としない。
  • 論文では1000エポック学習したとあったが100エポック行かないくらいでかなり形になった。

また、今回はGANのベストプラクティスの内、以下のトリックは実践しても効果がなかった.

  • Discriminatorに学習させるミニバッチをrealのみまたはfakeのみにする。(リンク元項目4)
  • GeneratorにもDiscriminatorにもLeaky-ReLUをつかう。(リンク元項目5)
  • GeneratorにADAMを使ってDiscriminatorにはSGDを使う。(リンク元項目10)
  • GeneratorにDropoutを使う。(リンク元項目17)

所感

実装に関して、やはりコード自体はkerasの方が圧倒的に簡単にかけるようになっているなと感じた。モデルのインスタンスを作ってバッチをmodel.train_on_batchに渡す処理を繰り返せば良いというイメージでとてもシンプルで良い。dcganの実装についても、①Descriminatorのみのモデル、②Descriminator(not trainable)+Generatorのモデルをそれぞれ定義することで学習を実現したが、それぞれ独立に学習、更新処理をやるよというのが表現しやすくてわかりやすかった。

一方でchainerはまずGPUがある環境とない環境でコードを分けなくてはならず、そこがまず面倒に感じた。自分はリモートマシンが研究室のマシンやGCPなどさまざまなので、手元のmacbookで実装してrsyncでデプロイという形で実装していたので処理をいちいち分岐させるのは面倒に感じた。

加えてtrainer, updater辺りの構造を理解するのに学習コストが要るなと感じた。extensionもロギングなどよく使う処理をパッケージ化できる所はメリットだと思うが、初めて書く人間にとってはさして複雑でない部分が隠蔽され逆にとっかかりづらいと感じる。ただ慣れてしまえばdcganの実装もkerasと同じくらいわかりやすくはあったのでこれからも忘れないように使っていきたいと思った。

今回ディープラーニングをやってみて、モデルの定義は簡単な反面、ハイパーパラメータのチューニングがべらぼうに難しいという印象だった。あと時間がかかる。今回の3DGANではGoogle Cloud Platform(GCP)を使っていて、2.2 GHz Intel Xeon E5 v4 (Broadwell) x 4、メモリ 16 GB、NVIDIA Tesla K80 x 4という環境だったのですが1エポック15分くらいかかった。1回50エポックで試すとしても一晩寝ても終わってない感じ。

そして、今回は自分なりのディープラーニングのやる方を考える良い機会にもなった。複数リモートがある場合のコードのホスティングは最初はgitを使っていたが、どうしても小さい変更が頻繁に入るのでrsyncを使ってデプロイするようにした。データセットは面倒だが毎回リモートでダウンロードした。

GCPは最初はどのタイプのマシンを使えば良いか慣れなかったのでGPUを使えるようにする環境構築をシェルスクリプトで自動化した。またGANは生成結果が重要なので、Chainerのextensionとしてエポック毎に生成結果を.binvoxで保存生成結果を.pngで保存生成結果(.png)をslackで通知するようにして、手元のスマホで確認するようにしていた。特に生成結果を画像で保存する場合はmatplotlibの描画の処理にかなり時間がかかるため、そこで学習が止まらないようにsubprocessを使って別プロセスで非同期でやるようにしたところが楽しかった。

f:id:bonhito:20171025201323p:plain:w350

おしまい。