由于这段代码中的某些内容,我的Google Colab一直崩溃。不知道是什么

超级机器人

我的Google colab一直在火车上崩溃,即使RAM和磁盘很多。我认为问题出在代码中的某个地方,但我不知道它是什么。我在做LSTM。我将不胜感激。我正在使用PyTorch。

此代码后跟生成函数,编码器,解码器类等。这是在我训练时崩溃的原因(由于“未知原因”)

class LSTMLM(torch.nn.Module):
  def __init__(self,
              vocab_size,
              embedding_size,
              hidden_size,
              num_layers=1,
              dropout=0.1):
    super().__init__()
    self.vocab_size = vocab_size
    self.embedding_size = embedding_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.dropout = torch.nn.Dropout(dropout)

    self.embedding = torch.nn.Embedding(vocab_size, embedding_size)

    self.lstm = torch.nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)

    self.logistl = torch.nn.Linear(hidden_size, vocab_size)

    pass

  def forward(self, x, init_hidden_state=None):
    assert x.shape[0] == 1

    emb = self.embedding(x)

    emb = self.dropout(emb)

    if init_hidden_state is None:
      h0 = torch.zeros(self.num_layers, 1, self.hidden_size)
      c0 = torch.zeros(self.num_layers, 1, self.hidden_size)
    else:
      h0, c0 = init_hidden_state
    output, (hn, hc) = self.lstm(emb, (h0, c0))

    hidden_states = output
    final_hidden_state = hn
    final_cell_state = cn
    final_state = [final_hidden__state, final_cell_state]

    hidden_states = self.dropout(hidden_states) 

    output_dist = self.logistl(hidden_states)

    return output_dist, hidden_states, final_state 

超级机器人

解决方案是在定义每个张量之后添加cuda()。

例如

a = a.cuda()

请注意,您必须分配它,而不仅仅是 a.cuda()

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

我的程序一直循环,我不知道为什么

我的根分区一直保持100%的容量,但是我不知道为什么!

我正在尝试登录并检查用户是否存在,但我的执行标量一直返回 null,我不知道为什么?

我一直有一个关于 .nib 文件的错误,我不知道它是如何创建的

PyGame不断崩溃,我不知道为什么

我不知道 sql 语法中的错误是什么

我不知道这个错误是什么意思

我不知道xcode的iOS元素是什么

TTL文件格式-我不知道这是什么

我还不知道的演员类型是什么

我不知道预期的表情是什么

Java ArrayList,我不知道 IndexOutOfBoundsException 是什么

是什么导致我的代码不知道我想从本地存储中删除哪个值?

我不知道这段代码中的这些符号是什么意思。十进制到二进制

我需要初始化一个我以前真的不知道的东西,它在Ruby中是什么

我有一个内存错误,我不知道是什么导致它在 C++ 中

我创建了一个后缀数组,但我不知道这段代码有什么问题

我不知道为什么我的代码是错误的?那是什么错呢?

Pygame的屏幕崩溃崩溃,我不知道为什么

欧拉1号项目。我一直把答案弄错了100个,而且我也不知道为什么(写成F#)

我的电脑一直死机。会是什么呢?

有谁知道为什么这会一直崩溃,当我添加平均值时它开始崩溃

我不知道这段代码有什么问题,谁能帮助我:

我不知道为什么这段代码中有语法错误

我不知道这段 PHP 代码有什么问题

我可以ping IP地址,但是我不知道网络中的设备是什么

我的应用程序崩溃了,但我不知道为什么?

我不知道为什么我的应用程序不停止崩溃

我的代码上出现以下错误,我不知道它们是什么意思,也不知道如何解决它们