Kmeans聚类如何在张量流中工作?

兹蒂克·本·沙巴特

我看到在tensorflow contrib库中有一个Kmeans集群的实现。但是,我无法简单地估算2D点的聚类中心。

码:

## Generate synthetic data
N,D = 1000, 2 # number of points and dimenstinality

means = np.array([[0.5, 0.0],
                  [0, 0],
                  [-0.5, -0.5],
                  [-0.8, 0.3]])
covs = np.array([np.diag([0.01, 0.01]),
                 np.diag([0.01, 0.01]),
                 np.diag([0.01, 0.01]),
                 np.diag([0.01, 0.01])])
n_clusters = means.shape[0]

points = []
for i in range(n_clusters):
    x = np.random.multivariate_normal(means[i], covs[i], N )
    points.append(x)
points = np.concatenate(points)

## construct model
kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters)
kmeans.fit(points.astype(np.float32))

我收到以下错误:

InvalidArgumentError (see above for traceback): Shape [-1,2] has negative dimensions
     [[Node: input = Placeholder[dtype=DT_FLOAT, shape=[?,2], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

我想我做错了什么,但无法从文档中找出答案。

编辑

我使用来解决它,input_fn但是它确实很慢(我不得不将每个群集中的点数减少到10以查看结果)。为什么会这样,我如何使其更快?

 def input_fn():
    return tf.constant(points, dtype=tf.float32), None

## construct model
kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters, relative_tolerance=0.0001)
kmeans.fit(input_fn=input_fn)
centers = kmeans.clusters()
print(centers)

解决了:

似乎应该设定相对公差。所以我只更改了一条线,效果很好。kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters, relative_tolerance=0.0001)

和萨罗

您的原始代码使用Tensorflow 1.2返回以下错误:

    WARNING:tensorflow:From <stdin>:1: calling BaseEstimator.fit (from         
    tensorflow.contrib.learn.python.learn.estimators.estimator) with x 
    is deprecated and will be removed after 2016-12-01.
    Instructions for updating:
    Estimator is decoupled from Scikit Learn interface by moving into
    separate class SKCompat. Arguments x, y and batch_size are only
    available in the SKCompat class, Estimator will only accept input_fn.

根据您的编辑,您似乎已经确定这input_fn是唯一可接受的输入。如果您真的想使用TF,我将升级到r1.2并将Estimator包装在SKCompat类中,这是错误消息所提示的。否则,我只会使用SKLearn软件包。您也可以手动在TF中实现自己的聚类算法,如本博客所示

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章