如何在tensorflow占位符中使用get_collection

守护者121

因此,我在Feed变量方面遇到了一些问题。我希望冻结模型的权重和偏差。我有下一个变量:

wc1 = tf.Variable(tf.random_normal([f1, f1, _channel, n1], mean=0, stddev=0.01), name="wc1")
wc2 = tf.Variable(tf.random_normal([f2, f2, n1, n2], mean=0, stddev=0.01), name="wc2")
wc3 = tf.Variable(tf.random_normal([f3, f3, n2, _channel], mean=0, stddev=0.01), name="wc3") 

bc1 = tf.Variable(tf.random_normal(shape=[n1], mean=0, stddev=0.01), name="bc1")
bc2 = tf.Variable(tf.random_normal(shape=[n2], mean=0, stddev=0.01), name="bc2")
bc3 = tf.Variable(tf.random_normal(shape=[_channel], mean=0, stddev=0.01), name="bc3")

例如,我要在前10个时期训练[wc1,bc1],然后在下一个时期训练[wc2,bc2],依此类推。为此,我创建了变量集合:

tf.add_to_collection('wc1', wc1)
tf.add_to_collection('wc1', bc1)

tf.add_to_collection('wc2', wc2)
tf.add_to_collection('wc2', bc2)

并为集合名称创建占位符:

trainable_name = tf.placeholder(tf.string, shape=[])

接下来,我尝试在优化器中获取它:

opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = opt.minimize(cost, var_list=tf.get_collection(trainable_name))

供稿数据:

sess.run(train_op, feed_dict={ ... , trainable_name: "wc1"})

我得到错误:

 Traceback (most recent call last):
  File "/home/keeper121/PycharmProjects/super/sp_train.py", line 292, in <module>
    train(tiles_names, "model.ckpt")
  File "/home/keeper121/PycharmProjects/super/sp_train.py", line 123, in train
    train_op = opt.minimize(cost, var_list=tf.get_collection(trainable_name))
  File "/home/keeper121/anaconda/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/optimizer.py", line 193, in minimize
    grad_loss=grad_loss)
  File "/home/keeper121/anaconda/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/optimizer.py", line 244, in compute_gradients
    raise ValueError("No variables to optimize")
ValueError: No variables to optimize

那么,有什么方法可以在整个会话期间更改训练变量?

谢谢。

菲利普·博克(Phillip Bock)

尝试以下操作:

train_op_wc1 = opt.minimize(cost, var_list=tf.get_collection("wc1"))
train_op_wc2 = opt.minimize(cost, var_list=tf.get_collection("wc2"))

然后,当您输入数据时:

#define your samples as you would always do
input_feed = ...
#then use the training op that addresses the correct layers, as you defined above
if first_10_epoch:
  sess.run(train_op_wc1, feed_dict=input_feed)
else:
  sess.run(train_op_wc2, feed_dict=input_feed)

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

如何在 dplyr 中使用占位符?

如何在Angular Material 7拖放中使用占位符?

如何在python字典中使用占位符?

如何在.yml文件中使用属性占位符

如何在我的 Formbuilder (Symfony 4) 中使用占位符?

如何在for循环中使用占位符?

如何在输入占位符中使用FormattedMessage

如何在intl.formatMessage中使用占位符

如何在 jinja2 模板中使用占位符?

如何在 HTML 中使用多个不同颜色的占位符?

如何在 JSON (RedBeanPHP) 中使用占位符?

当有其他占位符时,如何在IN()子句中使用占位符?

无法在tensorflow中使用占位符初始化变量

在 tensorflow 中使用公共输入对占位符进行分组

在Tensorflow中使用2D占位符遮罩3D占位符张量

如何在TensorFlow中遍历可变长度的占位符?

如何在keras / tensorflow中为占位符提供值

如何在Tensorflow中复制操作和占位符

React-Intl如何在输入占位符中使用FormattedMessage

如何在颤动中使用 Image.network() 放置图像占位符?

如何在Flyway命令行中使用占位符前缀

如何在Java文本块中使用变量值的占位符

如何在XML中使用Gradle占位符applicationId来获得唯一的contentProvider权限

如何在文件系统上的文件中使用属性占位符

如何在骆驼处理器中使用属性占位符

如何在PyCharm替换功能中使用正则表达式占位符

在bash printf格式中如何在muliples占位符中使用相同的值

如何在 Jetpack Compose 中使用图像占位符进行预览

如何在pyrocms中使用流API为流字段的输入定义占位符?