如何创建简单的三层神经网络并使用监督学习进行教学?

路加

基于PyBrain的教程,我设法将以下代码组合在一起:

#!/usr/bin/env python2
# coding: utf-8

from pybrain.structure import FeedForwardNetwork, LinearLayer, SigmoidLayer, FullConnection
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import BackpropTrainer

n = FeedForwardNetwork()

inLayer = LinearLayer(2)
hiddenLayer = SigmoidLayer(3)
outLayer = LinearLayer(1)

n.addInputModule(inLayer)
n.addModule(hiddenLayer)
n.addOutputModule(outLayer)

in_to_hidden = FullConnection(inLayer, hiddenLayer)
hidden_to_out = FullConnection(hiddenLayer, outLayer)

n.addConnection(in_to_hidden)
n.addConnection(hidden_to_out)

n.sortModules()

ds = SupervisedDataSet(2, 1)
ds.addSample((0, 0), (0,))
ds.addSample((0, 1), (1,))
ds.addSample((1, 0), (1,))
ds.addSample((1, 1), (0,))

trainer = BackpropTrainer(n, ds)
# trainer.train()
trainer.trainUntilConvergence()

print n.activate([0, 0])[0]
print n.activate([0, 1])[0]
print n.activate([1, 0])[0]
print n.activate([1, 1])[0]

本来应该学习XOR函数,但结果似乎是随机的:

0.208884929522

0.168926515771

0.459452834043

0.424209192223

要么

0.84956138664

0.888512762786

0.564964077401

0.611111147862

BartoszKP

您的方法有四个问题,在阅读《神经网络常见问题》后都可以轻松找到

  • 为什么要使用偏置/阈值?:您应该添加一个偏置节点。偏见的缺乏使学习非常有限:网络代表的分离的超平面只能通过原点。使用bias节点,它可以自由移动并更好地拟合数据:

    bias = BiasUnit()
    n.addModule(bias)
    
    bias_to_hidden = FullConnection(bias, hiddenLayer)
    n.addConnection(bias_to_hidden)
    
  • 为什么不将二进制输入编码为0和1?:所有样本都位于样本空间的一个象限中。移动它们使其分散在原点周围:

    ds = SupervisedDataSet(2, 1)
    ds.addSample((-1, -1), (0,))
    ds.addSample((-1, 1), (1,))
    ds.addSample((1, -1), (1,))
    ds.addSample((1, 1), (0,))
    

    (相应地,将验证码固定在脚本的末尾。)

  • trainUntilConvergence该方法使用验证工作,并且执行类似于早期停止方法的操作对于这么小的数据集,这没有任何意义。使用trainEpochs代替。1000对于这个问题,时代对于网络而言已经足够了:

    trainer.trainEpochs(1000)
    
  • 反向传播应使用哪种学习率?:调整学习率参数。每当您使用神经网络时,便会执行此操作。在这种情况下,该值0.1甚至会0.2大大提高学习速度:

    trainer = BackpropTrainer(n, dataset=ds, learningrate=0.1, verbose=True)
    

    (请注意verbose=True参数。调整参数时,观察错误的行为至关重要。)

有了这些修复程序,我就可以针对具有给定数据集的给定网络获得一致且正确的结果,并且误差小于1e-23

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章