使用keras模型解决多标签问题的scikit学习链分类器的拟合方法错误

Aizayousaf

我正在为使用KerasClassifier模型的多类问题构建链分类器。我有17个标签作为分类目标,X_train的形状为(111300,107),y_train的形状为(111300,17)我的代码在这里:

def create_model():
  input_size=length_long_sentence
  embedding_size=128
  lstm_size=64
  output_size=len(unique_tag_set)
    #----------------------------Model -------------------------------
  current_input=Input(shape=(input_size,)) 
  emb_current = Embedding(vocab_size, embedding_size, input_length=input_size)(current_input)
  out_current=Bidirectional(LSTM(units=lstm_size))(emb_current )
  #out_current = Reshape((1,2*lstm_size))(out_current)
  output = Dense(units=len(unique_tag_set), activation='softmax')(out_current)
  model = Model(inputs=current_input, outputs=output)
  model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])
  print(model.summary())
  return model

   model = KerasClassifier(build_fn=create_model, epochs=1,batch_size=256)
   print(type(model))
   chain=ClassifierChain(model, order='random', random_state=42)
   history=chain.fit(X_train, y_train)

模型摘要在这里:

在此处输入图片说明

当尝试在ClassifierChain上使用fit方法时,出现此错误:

在此处输入图片说明

任何人都可以指导我这个错误,什么是(None,2)?

文卡塔恰兰

从链式分类器的记录中:

将二元分类器排列成一个链的多标签模型。

因此,使用最后一层中的单个节点将keras模型转换为二进制分类器,并将损失函数转换为binary_crossentropy

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

拟合多标签文本分类模型时的错误

使用 keras 拟合深度学习模型

Scikit学习多标签分类,从MultiLabelBinarizer获取标签

并行拟合scikit学习模型?

使用scikit-learn进行多标签文本分类,使用哪些分类器?

Scikit学习多标签分类:ValueError:您似乎正在使用旧的多标签数据表示形式

在 scikit-learn 中使用图像数据拟合支持向量分类器会产生错误

Keras多标签分类“ to_categorical”错误

使用保存的分类器/模型时出现“ IDF向量未拟合”错误

使用scikit学习训练Logistic回归进行多类别分类

拟合scikit学习决策树和随机森林分类器时出现MemoryError

使用 TensoFlow 的多类分类标签错误

软标签上的scikit学习分类

多标签分类Keras指标

使用深度学习防止在多类别分类中过度拟合特定类别

计算ROC曲线,分类报告和混淆矩阵以解决多标签分类问题

如何使用keras实现多标签分类神经网络

多标签分类:如何学习阈值?

如何从Scikit学习的拟合模型中获取属性列表?

拟合分类器中的奇怪错误

使用Naive Bayes进行10类交叉验证的Scikit学习进行多类分类

使用WEKA在多标签设置中使用kNN分类器

Keras-分类器无法从预先训练的模型的传递值中学习

Keras MLP分类器无法学习

分类器中的scikit-learn改装/部分拟合选项

如何使用Keras训练多类图像分类器

使用混淆矩阵了解多标签分类器

使用以下工具的Scikit学习多输出分类器:GridSearchCV,管道,OneVsRestClassifier,SGDClassifier

使用多个分类器时-如何衡量整体表现?[SciKit学习]