用 tensorflow 1 时,想要多次重复实验取平均,在两次实验之间需要清一次计算图,否则会报错说 xx 变量重复定义。代码形式:
# import tensorflow
class MyModel:
def __init__(self):
# build model
def train(self):
with tf.Session() as sess:
# training
# 多 runs 取平均
for i_run in range(n_run):
model = MyModel()
model.train()
# 两 runs 之间清一次计算图
tf.reset_default_graph()
print("DONE")