TensorFlow2:同一スクリプト内で複数回の学習を回す

パラメータを振った複数の学習を、一気に流したいときにどうするかという話。

エラー落ちする対策をやってみたが、最後に紹介するやり方が素直な気がする。

実行するモデル

今回使うモデルはこちら。weight decayをパラメータでとるようにしている。

import tensorflow as tf

# asign aliases
tfk = tf.keras
tfkl = tf.keras.layers


class MyModel2(tfk.Model):
    def __init__(self, wd=0.001):  # add `weight decay` parameter.
        super(MyModel2, self).__init__()
        self.conv1 = tfkl.Conv2D(32, 3, padding="same", activation="relu", use_bias=True,
                                 kernel_regularizer=tfk.regularizers.l2(wd),
                                 name="conv1")
        self.pool1 = tfkl.MaxPool2D(name="pool1")
        self.conv2 = tfkl.Conv2D(64, 3, padding="same", activation="relu", use_bias=True,
                                 kernel_regularizer=tfk.regularizers.l2(wd),
                                 name="conv2")
        self.pool2 = tfkl.MaxPool2D(name="pool2")
        self.flatten = tfkl.Flatten()
        self.d1 = tfkl.Dense(128, activation="relu",
                             kernel_regularizer=tfk.regularizers.l2(wd),
                             name="fc1")
        self.d2 = tfkl.Dense(10, activation="softmax", name="softmax")

    def call(self, inputs, training=False):
        x = self.conv1(inputs)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

素直に2回学習を実行してみる(エラーが起きる)

見出しの通り。そのままやると以下エラーが起きます。

ValueError: tf.function-decorated function tried to create variables on non-first call.

実行部分はこんな感じ。端折ってるので、適当に補完して読んでください。

    # ...skip

    print("loading dataset")
    ds_train, ds_test = load_dataset()

    WD = 0.01
    print("building model with wd={}".format(WD))
    model = MyModel2(wd=WD)
    optimizer = tf.keras.optimizers.Adam()
    training(model, optimizer, args.epochs, ds_train, ds_test, logdir)

    # second time in same script will rise error:
    # ValueError: tf.function-decorated function tried to create variables on non-first call.

    WD = 0.001
    print("building model with wd={}".format(WD))
    model = MyModel2(wd=WD)
    optimizer = tf.keras.optimizers.Adam()
    training(model, optimizer, args.epochs, ds_train, ds_test, logdir)

エラーの対策を行う

@tf.function を指定している関数が、1回目と2回目で違うことが問題になっているよう。

@tf.functionの関数と、モデルは、それぞれ初回実行時にビルドするけど、関数は同じだけどモデル側だけ再生成がかかって、それをエラーとしてフックしていると解釈した。

いつもどおりissueを漁ってみる。あったあった。

tf.function-decorated function tried to create variables on non-first call #27120

以下のコメントで暫定回避策が記載されている。
https://github.com/tensorflow/tensorflow/issues/27120#issuecomment-593001572

適当に要約すると:

  • もともとの@tf.functionの関数を返すラッパー関数を作って
  • 学習部ではいきなり@tf.functionの関数を呼び出してみて
    • UnboundLocalError: そんな関数ないよ(スコープ外)エラーが起きたらラッパーで関数を生成
    • ValueError: 今回当たった、モデルが変わったよエラーが起きたらラッパーで関数を生成
  • 同一モデルの2回目以降の呼び出しは、先に生成した@tf.functionの関数を呼ぶ

トリッキーだ。やってみる。

def step_train_wrapper(model, optimizer, loss_obj):
    # avoid `tf.function-decorated function tried to create variables on non-first call.`
    # this is for run multiple graph in single file.
    # https://github.com/tensorflow/tensorflow/issues/27120
    @tf.function
    def step_train(images, labels, metr_loss, metr_acc):
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = loss_obj(labels, predictions)
            gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        metr_loss(loss)
        metr_acc(labels, predictions)
        return loss
    return step_train
def training(model, optimizer, epochs, ds_train, ds_test, logdir=None):
    # ...skip
    for epoch in range(1, epochs+1):
        # do training.
        for i, (images, labels) in enumerate(ds_train, 1):
            try:
                loss_step = step_train(images, labels, metr_train_loss, metr_train_acc)  # type: ignore # noqa
            except (UnboundLocalError, ValueError):
                step_train = step_train_wrapper(model, optimizer, loss_obj)
                loss_step = step_train(images, labels, metr_train_loss, metr_train_acc)
    # ...skip

確かに。うまく動く。

ちなみに、連続して学習を実行すると、Eager効果でレイヤ名など重複して自動で「_1」とか付与されることがある。
(ということは、レイヤ名以外にも副作用があるような気がする)

以下のように学習ごとにセッションをリセットしておくと避けることができるようだ。

    # reset graph (avoid auto renaming (add suffix) for model/layer)
    tf.keras.backend.clear_session()

別のやり方

今回は、同じデータセット使うので、読み込みを端折れないかというモチベーションだったが、
そこはあきらめて素直にスクリプトごと実行するようにしてみる。

パラメータに基づいて1パターンの学習を行うスクリプトと、パラメータを振るスクリプトの、2ファイル構成。

まずは学習を行うスクリプト。
train, testのwrapperを戻した後、weight decayパラメータを引数で受け取るようにする。

import argparse

# ...skip

if __name__ == "__main__":
    # parse args.
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument("--wd", dest="wd", metavar="weightdecay",
                        type=float, default=0.01,
                        help="weight decay in layer")
    parser.add_argument("--logdir", dest="logdir",
                        help="directory to store log and model (auto create subdir using datetime)")
    parser.add_argument("epochs",
                        type=int,
                        help="training epoch")
    args = parser.parse_args()

    print("loading dataset")
    ds_train, ds_test = load_dataset()

    print("building model with wd={}".format(args.wd))
    model = MyModel2(wd=args.wd)
    optimizer = tf.keras.optimizers.Adam()
    if args.logdir:
        logdir = make_log_path(args.logdir, "wd-{}".format(args.wd))
    else:
        logdir = None
    training(model, optimizer, args.epochs, ds_train, ds_test, logdir)

    pass

次に、パラメータを振るスクリプト。
training-multi2.pyは、上で作成した学習を行うスクリプトの名前。

import subprocess

LOGDIR = "out"
EPOCHS = 3
WD_PATTERN = [
    0.01, 0.001,
]

if __name__ == "__main__":
    param_logdir = "--logdir={}".format(LOGDIR)
    param_epochs = "{}".format(EPOCHS)
    for weight_decay in WD_PATTERN:
        param_wd = "--wd={}".format(weight_decay)
        subprocess.run([
            "python", "training-multi2.py",
            param_wd, param_logdir, param_epochs
        ])

    pass

subprocess.run()の引数は最低限の指定なので、使い方によっては検討必要かも。

これでも、もともとやりたかったパラメータを振って実行は実現できる。

まとめ

  • 同一スクリプトで複数回学習を実行するには、Hack的なコードが必要
    • @tf.functionの関数を再生成させる必要がある
  • けど副作用が怖いので、素直にスクリプトごと実行しなおしたほうが安全な気がする

データセットの使いまわしを考えたけど、大規模かつ全部メモリに乗せるような使い方でないと、気にしなくて良いかなと思い直した。
それより副作用が怖い。

コメント

このブログの人気の投稿

TensorFlow2:TensorBoardのグラフがうまく表示されず困った件

TensorFlow2:kerasの継承モデルのSummaryを表示する

TensorFlow2.1+VSCodeで補完が効かない件