TensorFlow2:kerasの継承モデルのSummaryを表示する

前回のMNISTモデルsummary()を呼び出してモデル構造を見ようと思ったエラーになったので調べた件。
(kerasは使ってなかったけど、パラメータ数とか出してくれるのは羨ましかったのよね)

build()call()を実行してモデルを構築すると出力されるようになる。
以下詳細。

エラーの内容

以下のようなコードを書くとエラーが出る。

class MyModel(tfk.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = tfkl.Conv2D(32, 3, padding="same",
                                 activation="relu", use_bias=True, name="conv1")
        # ... skip

    def call(self, inputs, training=False):
        x = self.conv1(inputs)
        # ... skip
        return self.d2(x)
    model = MyModel()
    model.summary()
    # this will error occored
ValueError: This model has not yet been built. Build the model first by calling `build()` or
calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s)
for automatic build.`

対策(1)

エラーメッセージに素直に従って、build()を呼び出してみる。

    model = MyModel()
    model.build((None, 28, 28, 1))  # build with input shape.
    model.summary()

build()は、入力のshapeを渡す。

Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv1 (Conv2D)               multiple                  320
_________________________________________________________________
pool1 (MaxPooling2D)         multiple                  0
_________________________________________________________________
conv2 (Conv2D)               multiple                  18496
_________________________________________________________________
pool2 (MaxPooling2D)         multiple                  0
_________________________________________________________________
flatten (Flatten)            multiple                  0
_________________________________________________________________
fc1 (Dense)                  multiple                  401536
_________________________________________________________________
softmax (Dense)              multiple                  1290
=================================================================
Total params: 421,642
Trainable params: 421,642
Non-trainable params: 0

OK。パラメータ数が出るようになった。これでサイズの計算ができる!

しかし、Output Shapeがmultipleなのが気になるな。

対策(2)

グラフ構築がうまくいってないのかな。
ということで実際にcall()で計算を呼び出してみると、Output Shapeが表示されるようになる。
実データ入れなくても動くのね。

    model = MyModel()
    model.build((None, 28, 28, 1))  # build with input shape.
    dummy_input = tf.keras.Input(shape=(28, 28, 1))  # declare without batch demension.
    model.call(dummy_input)
    model.summary()

call()では、keras.Inputのレイヤを作成してから渡す。
気を付けることは、build()はバッチを含めたshapeにするが、keras.Inputは、バッチを含まない。

Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv1 (Conv2D)               (None, 28, 28, 32)        320
_________________________________________________________________
pool1 (MaxPooling2D)         (None, 14, 14, 32)        0
_________________________________________________________________
conv2 (Conv2D)               (None, 14, 14, 64)        18496
_________________________________________________________________
pool2 (MaxPooling2D)         (None, 7, 7, 64)          0
_________________________________________________________________
flatten (Flatten)            (None, 3136)              0
_________________________________________________________________
fc1 (Dense)                  (None, 128)               401536
_________________________________________________________________
softmax (Dense)              (None, 10)                1290
=================================================================
Total params: 421,642
Trainable params: 421,642
Non-trainable params: 0
_________________________________________________________________

うん。十分じゃないかな。

グラフの出力

そういえば、keras.utils.plot_modelでモデルのグラフを画像で出力できるらしい。

Reference: The Keras functional API

Windowsだと、

  • OSにGraphvizをインストールし、graphviz/binにPATHを通す
  • pip installl graphviz
  • pip install pydot-ng
    して、いざ。
    model = MyModel()
    model.build((None, 28, 28, 1))  # build with input shape.
    dummy_input = tf.keras.Input(shape=(28, 28, 1))  # declare without batch demension.
    model.call(dummy_input)
    model.summary()
    tf.keras.utils.plot_model(model, "out/model2.png", show_shapes=True)

ざんねん。肝心の中身がわからない。

githubのissueによると、「subclassedではなく、Functional API」を使ってね、らしい。

Graph visualization of subclassed model/layer
Yes in general we can't assume anything about the structure of a subclassed Model. If your Model can be though of as blocks of Layers and you wish to visualize it like that, we recommend you view the Functional API

先のTensorFlowのページによると、Functional APIは、
keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')
のように、inputsとoutputsを指定して、keras.Modelを直接インスタンス化する方法のようだ。

確かに、tf.keras.Modelを辿っていくと、inpusとoutputsの有無で処理が分かれている。

tensorflow/python/keras/engine/network.py

  def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
    # Signature detection
    if (len(args) == 2 or
        len(args) == 1 and 'outputs' in kwargs or
        'inputs' in kwargs and 'outputs' in kwargs):
      # Graph network
      self._init_graph_network(*args, **kwargs)
    else:
      # Subclassed network
      self._init_subclassed_network(**kwargs)

対策

最初に:
ちょうどハマってる時に記事が上がっていました。
丁寧なので、こっちも見てみてください。今回のはこれを試してみた話。

第5回 お勧めの、TensorFlow 2.0最新の書き方入門(エキスパート向け) (1/2)
https://www.atmarkit.co.jp/ait/articles/2003/10/news016.html

keras.Modelは、レイヤ以外にモデルも内包できるので、inputsとoutputsを指定したModelでラップしてやる。

    model = MyModel()
    model.build((None, 28, 28, 1))  # build with input shape.
    dummy_input = tf.keras.Input(shape=(28, 28, 1))  # declare without batch demension.
    model_summary = tf.keras.Model(inputs=[dummy_input], outputs=model.call(dummy_input))
    model_summary.summary()
    tf.keras.utils.plot_model(model_summary, "out/model3.png", show_shapes=True)

うまく出力できました。

まとめ

  • subclassの場合は、モデルがbuildされていないのでサマリ出力やグラフ出力に制限がかかる
  • サマリ出力を行いたい場合は、build()call()を呼び出す
  • グラフ出力も行いたい場合は、keras.Model(inputs=,outputs=)でラップする

とりあえず対策(2)で十分かな。
複数の入力/出力パスを持つグラフになった場合は、グラフを可視化できると便利かな。

コメント

このブログの人気の投稿

TensorFlow2:TensorBoardのグラフがうまく表示されず困った件

TensorFlow2:TensorBoardのグラフがうまく表示されず困った件(余談)