如何使用均值和标准差从pyspark中的多个列中删除离群值

我有下面的数据框,我想从定义的列中删除离群值。在下面的示例中,价格和收入。对于每组数据,应删除异常值。在此示例中,其“ cd”和“ segment”列。应基于5个标准偏差删除异常值。

data = [
  ('a', '1',20,10),   
  ('a', '1',30,16),
  ('a', '1',50,91),
    ('a', '1',60,34),
    ('a', '1',200,23),
  ('a', '2',33,87),
  ('a', '2',86,90),
      ('a','2',89,35),
    ('a', '2',90,24),
    ('a', '2',40,97),
  ('a', '2',1,21),
  ('b', '1',45,96),   
  ('b', '1',56,99),
  ('b', '1',89,23),
    ('b', '1',98,64),
    ('b', '2',86,42),
  ('b', '2',45,54),
  ('b', '2',67,95),
      ('b','2',86,70),
    ('b', '2',91,64),
    ('b', '2',2,53),
  ('b', '2',4,87)
]
data = (spark.createDataFrame(data, ['cd','segment','price','income']))

我已使用以下代码删除异常值,但这仅适用于一列。

mean_std = (
    data
    .groupBy('cd', 'segment')
    .agg(
      *[f.mean(colName).alias('{}{}'.format('mean_',colName)) for colName in ['price']],
      *[f.stddev(colName).alias('{}{}'.format('stddev_',colName)) for colName in ['price']])
)


mean_columns = ['mean_price']
std_columns = ['stddev_price']
upper = mean_std
for col_1 in mean_columns:
    for col_2 in std_columns:
      if col_1 != col_2:
        name = col_1 + '_upper_limit'
        upper = upper.withColumn(name, f.col(col_1) + f.col(col_2)*5)
        
lower = upper
for col_1 in mean_columns:
    for col_2 in std_columns:
      if col_1 != col_2:
        name = col_1 + '_lower_limit'
        lower = lower.withColumn(name, f.col(col_1) - f.col(col_2)*5)
        
outliers = (data.join(lower, 
                                how = 'left',
                                on = ['cd', 'segment'])
                           .withColumn('is_outlier_price', f.when((f.col('price')>f.col('mean_price_upper_limit')) |
                                                           (f.col('price')<f.col('mean_price_lower_limit')),1)
                                                      .otherwise(None))
           )

我的最终输出应该为每个变量都有一个列,说明其变量1 = remove还是0 = keep。

非常感谢您对此提供的任何帮助。

维纳

您的代码几乎可以100%正常工作。您要做的就是用一个列名数组替换单个固定的列名,然后遍历该数组:

numeric_cols = ['price', 'income']
mean_std = \
    data \
    .groupBy('cd', 'segment') \
    .agg( \
      *[F.mean(colName).alias('mean_{}'.format(colName)) for colName in numeric_cols],\
      *[F.stddev(colName).alias('stddev_{}'.format(colName)) for colName in numeric_cols])

mean_std现在是一个数据框,其中每个元素有两列(mean_...stddev_...numeric_cols

在下一步中,我们计算以下元素的上下限numeric_cols

mean_std_min_max = mean_std
for colName in numeric_cols:
    meanCol = 'mean_{}'.format(colName)
    stddevCol = 'stddev_{}'.format(colName)
    minCol = 'min_{}'.format(colName)
    maxCol = 'max_{}'.format(colName)
    mean_std_min_max = mean_std_min_max.withColumn(minCol, F.col(meanCol) - 5 * F.col(stddevCol))
    mean_std_min_max = mean_std_min_max.withColumn(maxCol, F.col(meanCol) + 5 * F.col(stddevCol))

mean_std_min_max现在包含两个附加列min_...max...每个元素numeric_cols

最后进行联接,然后is_outliers_...像以前一样计算列:

outliers = data.join(mean_std_min_max, how = 'left', on = ['cd', 'segment'])
for colName in numeric_cols:
    isOutlierCol = 'is_outlier_{}'.format(colName)
    minCol = 'min_{}'.format(colName)
    maxCol = 'max_{}'.format(colName)
    meanCol = 'mean_{}'.format(colName)
    stddevCol = 'stddev_{}'.format(colName)
    outliers = outliers.withColumn(isOutlierCol, F.when((F.col(colName) > F.col(maxCol)) | (F.col(colName) < F.col(minCol)), 1).otherwise(0))    
    outliers = outliers.drop(minCol,maxCol, meanCol, stddevCol)

循环的最后一行仅是清理和删除中间列。注释掉可能会有所帮助。

最终结果是:

+---+-------+-----+------+----------------+-----------------+
| cd|segment|price|income|is_outlier_price|is_outlier_income|
+---+-------+-----+------+----------------+-----------------+
|  b|      2|   86|    42|               0|                0|
|  b|      2|   45|    54|               0|                0|
|  b|      2|   67|    95|               0|                0|
|  b|      2|   86|    70|               0|                0|
|  b|      2|   91|    64|               0|                0|
+---+-------+-----+------+----------------+-----------------+
only showing top 5 rows

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

如何使用 R 中的均值和标准差替换缺失值?

如何从列表中获取均值、中值和标准差

通过均值的标准差查找离群值,在大型数据集中(6000多个列)替换为NA

如何在data.table中创建均值和标准差列

给定PySpark DataFrame如何计算均值和标准差?

如何计算多个分组变量的均值和标准差?

使用numpy的多个数组的均值和标准差

在R data.table中,如何使用训练集的均值和标准差标准化测试集

如何使用 CSV 文件绘制均值和标准差?

计算时间序列中的均值和标准差

字典列表中的 Python 均值和标准差

ggplot 中的多条(均值和标准差)

如何使用逐元素运算获取多个numpy保存数组的均值和标准差

基于单个字典中的键的值的均值和标准差

Random.nextgaussian()可以从均值和标准差不同的分布中采样值吗?

如何计算字典中多个矩阵的均值/中位数/标准差?

将某些列重新调整为 R 中的特定均值和标准差

如何在ggplot中的箱线图上打印均值,中位数和标准差?

仅使用均值和标准差信息在 R 中绘制密度图

多个数据帧的均值和标准差

在python / pyspark中获取k均值质心和离群值

查找行的均值和标准差,直到R中的组的下一个NA值

插值数据集的均值和标准差 (R)

在Excel中,当可用数据是值和出现次数时,如何查找标准差?

数据帧非零列的均值和标准差

计算列均值和标准差的组

R:每个受试者的均值,方差和标准差列

熊猫计算两列的均值,标准差和计数

数据框列的均值和标准差