“ Flatten”在Keras中的作用是什么?

食肉动物:

我试图了解该Flatten功能在Keras中的作用下面是我的代码,它是一个简单的两层网络。它接收形状为(3,2)的二维数据,并输出形状为(1,4)的一维数据:

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

x = np.array([[[1, 2], [3, 4], [5, 6]]])

y = model.predict(x)

print y.shape

打印出y形状为(1、4)的图形。但是,如果我删除该Flatten行,则会打印出y形状为(1、3、4)的行。

我不明白 从我对神经网络的理解来看,该model.add(Dense(16, input_shape=(3, 2)))函数正在创建一个具有16个节点的隐藏的全连接层。这些节点中的每个都连接到3x2输入元素中的每个。因此,该第一层输出处的16个节点已经“平坦”。因此,第一层的输出形状应为(1、16)。然后,第二层将此作为输入,并输出形状为(1、4)的数据。

因此,如果第一层的输出已经“平坦”并且形状为(1,16),为什么还要进一步使其平坦?

MarcinMożejko:

如果您阅读的Keras文档条目Dense,将会看到以下调用:

Dense(16, input_shape=(5,3))

这样将形成一个Dense具有3个输入和16个输出网络,这些网络将独立应用于5个步骤中的每个步骤。因此,如果D(x)将3维矢量转换为16维矢量,则从图层输出的结果将是一系列矢量:[D(x[0,:]), D(x[1,:]),..., D(x[4,:])]shape (5, 16)为了具有您指定的行为,您可以先将Flatten输入输入到15维向量,然后应用Dense

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

编辑:由于有些人难以理解-在这里,您有一个解释性的图像:

在此处输入图片说明

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章