具有附加维度的二进制分类(子类)

Nukubiho

假设我有10个观测值(属于A或B类),5列和2个子类(C,D)作为附加维度的数组,并且我想在Keras R中进行二进制分类(至A或B类)在这种情况下,网络架构应该是什么样?

library("keras")

df = data.frame(class = c(rep("A", 10), rep("B", 10)),
                subclass = rep(c("C", "D"), 10),
                feature1 = rnorm(20),
                feature2 = rnorm(20),
                feature3 = rnorm(20))

df1 = df[df$subclass == "C", ]
df2 = df[df$subclass == "D", ]
df_list = list(df1, df2)

build_model = function() {
  model = keras_model_sequential() 

  model %>%
    # input_shape is 3 features and 2 subclasses
    layer_dense(units = 2, activation = 'sigmoid', input_shape = c(3, 2))

  model %>%
    compile(
      optimizer = "adam",
      loss = "binary_crossentropy",
      metrics = list("accuracy")
    )
}

# one hot encoding to A, B classes
labels = to_categorical(as.integer(df_list[[1]][, "class"]) - 1)

# drop factor columns
data = lapply(df_list, function(x) x[, -(1:2)])

# convert to array
data_array = array(unlist(c(data[[1]], data[[2]])), dim = c(10, 3, 2))

model = build_model()

# error appears in the following function:
history = model %>% fit(
  x = data_array,
  y = labels
)

错误:

py_call_impl中的错误(可调用,dots $ args,dots $ keywords):

ValueError:将形状为(10,2)的目标数组传递为形状(None,3,2)的输出,同时用作loss binary_crossentropy这种损失期望目标与输出具有相同的形状。

该错误与输入和输出数据的尺寸之间的差异有关,但我不知道它应该看起来像什么。我的样本数据维是10个观测值,3个特征和2个子类。

型号信息:

Model: "sequential"
____________________________________________________________________
Layer (type)                Output Shape               Param #    
====================================================================
dense (Dense)               (None, 3, 2)               6          
====================================================================
Total params: 6
Trainable params: 6
Non-trainable params: 0
____________________________________________________________________
Nukubiho

在乙状结肠层之前,神经网络体系结构需要一个扁平层。然后代码将起作用。

model %>%
  # input_shape is 3 features and 2 subclasses
  layer_flatten(input_shape = c(3, 2)) %>%
  layer_dense(units = 2, activation = 'sigmoid')

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章