如何使用Java多线程更快地计算Euler数

伊夫林·迪涅夫(Ivelin Dinev):

因此,我有一个任务要使用以下公式使用多个线程来计算欧拉数:sum((((3k)^ 2 +1)/((3k)!))),k = 0 ...无穷大。

import java.math.BigDecimal;
import java.math.BigInteger;
import java.io.FileWriter;
import java.io.IOException;
import java.math.RoundingMode;

class ECalculator {
  private BigDecimal sum;
  private BigDecimal[] series;
  private int length;
  
  public ECalculator(int threadCount) {
    this.length = threadCount;
    this.sum = new BigDecimal(0);
    this.series = new BigDecimal[threadCount];
    for (int i = 0; i < this.length; i++) {
      this.series[i] = BigDecimal.ZERO;
    }
  }
  
  public synchronized void addToSum(BigDecimal element) {
    this.sum = this.sum.add(element);
  }
  
  public void addToSeries(int id, BigDecimal element) {
    if (id - 1 < length) {
      this.series[id - 1] = this.series[id - 1].add(element);      
    }
  }
  
  public synchronized BigDecimal getSum() {
    return this.sum;
  }
  
  public BigDecimal getSeriesSum() {
    BigDecimal result = BigDecimal.ZERO;
    for (int i = 0; i < this.length; i++) {
      result = result.add(this.series[i]);
    }
    return result;
  }
}

class ERunnable implements Runnable {
  private final int id;
  private final int threadCount;
  private final int threadRemainder;
  private final int elements;
  private final boolean quietFlag;
  private ECalculator eCalc;

  public ERunnable(int threadCount, int threadRemainder, int id, int elements, boolean quietFlag, ECalculator eCalc) {
    this.id = id;
    this.threadCount = threadCount;
    this.threadRemainder = threadRemainder;
    this.elements = elements;
    this.quietFlag = quietFlag;
    this.eCalc = eCalc;
  }

  @Override
  public void run() {
    if (!quietFlag) {
      System.out.println(String.format("Thread-%d started.", this.id));      
    }
    long start = System.currentTimeMillis();
    int k = this.threadRemainder;
    int iteration = 0;
    BigInteger currentFactorial = BigInteger.valueOf(intFactorial(3 * k));
    
    while (iteration < this.elements) {
      if (iteration != 0) {
        for (int i = 3 * (k - threadCount) + 1; i <= 3 * k; i++) {
          currentFactorial = currentFactorial.multiply(BigInteger.valueOf(i));
        }
      }
      
      this.eCalc.addToSeries(this.id, new BigDecimal(Math.pow(3 * k, 2) + 1).divide(new BigDecimal(currentFactorial), 100, RoundingMode.HALF_UP));
      
      iteration += 1;
      k += this.threadCount;
    }
    
    long stop = System.currentTimeMillis();
    if (!quietFlag) {
      System.out.println(String.format("Thread-%d stopped.", this.id));
      System.out.println(String.format("Thread %d execution time: %d milliseconds", this.id, stop - start));      
    }
  }
  
  public int intFactorial(int n) {
    int result = 1;
    for (int i = 1; i <= n; i++) {
      result *= i;
    }
    return result;
  }
}

public class TaskRunner {
  public static final String DEFAULT_FILE_NAME = "result.txt";
  public static void main(String[] args) throws InterruptedException {

    int threadCount = 2;
    int precision = 10000;
    int elementsPerTask = precision / threadCount;
    int remainingElements = precision % threadCount;
    boolean quietFlag = false;
    
    calculate(threadCount, elementsPerTask, remainingElements, quietFlag, DEFAULT_FILE_NAME);
  }
  
  public static void writeResult(String filename, String result) {
    try {
      FileWriter writer = new FileWriter(filename);
      writer.write(result);
      writer.close();
    } catch (IOException e) {
      System.out.println("An error occurred.");
      e.printStackTrace();
    }
  }
  
  public static void calculate(int threadCount, int elementsPerTask, int remainingElements, boolean quietFlag, String outputFile) throws InterruptedException {
    long start = System.currentTimeMillis();
    Thread[] threads = new Thread[threadCount];
    ECalculator eCalc = new ECalculator(threadCount);
    
    for (int i = 0; i < threadCount; i++) {
      if (i == 0) {
        threads[i] = new Thread(new ERunnable(threadCount, i, i + 1, elementsPerTask + remainingElements, quietFlag, eCalc));
      } else {
        threads[i] = new Thread(new ERunnable(threadCount, i, i + 1, elementsPerTask, quietFlag, eCalc));        
      }
      threads[i].start();
    }
    
    for (int i = 0; i < threadCount; i++) {
      threads[i].join();
    }
    
    String result = eCalc.getSeriesSum().toString();
    
    if (!quietFlag) {
      System.out.println("E = " + result);      
    }
    
    writeResult(outputFile, result);
    
    long stop = System.currentTimeMillis();
    System.out.println("Calculated in: " + (stop - start) + " milliseconds" );
  }
}

我在代码中去除了无效的打印内容等。我的问题是,我使用的线程越多,它获得的速度就越慢。目前,我运行最快的是1个线程。我确信阶乘计算会引起一些问题。我尝试使用线程池,但是仍然有相同的时间。

  1. 如何使它运行到更多的线程,直到某个时候,将加快计算过程?
  2. 如何计算这么大的阶乘?
  3. 传递的precision参数是所使用的总和中元素的数量。是否可以将BigDecimal比例尺设置为某种程度上取决于该精度,所以我不对其进行硬编码?

编辑我将代码块更新为仅在1个文件中,并且无需外部库即可运行。

编辑2我发现阶乘代码弄乱了时间。如果我在不计算阶乘的情况下让线程上升到某个高精度,则时间会随着线程增加而减少。但是,在保持时间减少的同时,我无法以任何方式实施阶乘计算。

编辑3调整代码以解决答案。

    private static BigDecimal partialCalculator(int start, int threadCount, int id) {
        BigDecimal nBD = BigDecimal.valueOf(start);
        
        BigDecimal result = nBD.multiply(nBD).multiply(BigDecimal.valueOf(9)).add(BigDecimal.valueOf(1));
        
        for (int i = start; i > 0; i -= threadCount) {
            BigDecimal iBD = BigDecimal.valueOf(i);
            BigDecimal iBD1 = BigDecimal.valueOf(i - 1);
            BigDecimal iBD3 = BigDecimal.valueOf(3).multiply(iBD);
            
            BigDecimal prevNumerator = iBD1.multiply(iBD1).multiply(BigDecimal.valueOf(9)).add(BigDecimal.valueOf(1));
            
            // 3 * i * (3 * i - 1) * (3 * i - 2);
            BigDecimal divisor = iBD3.multiply(iBD3.subtract(BigDecimal.valueOf(1))).multiply(iBD3.subtract(BigDecimal.valueOf(2)));
            result = result.divide(divisor, 10000, RoundingMode.HALF_EVEN)
                                     .add(prevNumerator);
        }
        return result;
    }
    
    public static void main(String[] args) {
        int threadCount = 3;
        int precision = 6;
        
        ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
        ArrayList<Future<BigDecimal> > futures = new ArrayList<Future<BigDecimal> >();
        for (int i = 0; i < threadCount; i++) {
            int start = precision - i;
            System.out.println(start);
            final int id = i + 1;
            futures.add(executorService.submit(() -> partialCalculator(start, threadCount, id)));
        }
        BigDecimal result = BigDecimal.ZERO;
        try {
            for (int i = 0; i < threadCount; i++) {
                result = result.add(futures.get(i).get());
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        
        executorService.shutdown();
        System.out.println(result);
    }

似乎对于1个线程可以正常工作,但会使多个计算混乱。

乔尼:

在检查了更新的代码之后,我做了以下观察:

首先,该程序运行了不到一秒钟的时间。这意味着这是一个微观基准。Java中的几个关键功能使微基准测试难以可靠地实现。请参阅如何在Java中编写正确的微基准?例如,如果程序没有运行足够的重复,那么“及时”编译器将没有时间投入将其编译为本地代码,从而最终对解释器进行基准测试。在您的情况下,当有多个线程时,JIT编译器可能需要更长的时间才能启动,

例如,为了使您的程序执行更多工作,我将BigDecimal精度从100 更改为10,000,并在main方法周围添加了一个循环。执行时间的度量如下:

1个线程:

Calculated in: 2803 milliseconds
Calculated in: 1116 milliseconds
Calculated in: 1040 milliseconds
Calculated in: 1066 milliseconds
Calculated in: 1036 milliseconds

2个线程:

Calculated in: 2354 milliseconds
Calculated in: 856 milliseconds
Calculated in: 624 milliseconds
Calculated in: 659 milliseconds
Calculated in: 664 milliseconds

4个线程:

Calculated in: 1961 milliseconds
Calculated in: 797 milliseconds
Calculated in: 623 milliseconds
Calculated in: 536 milliseconds
Calculated in: 497 milliseconds

第二个观察结果是,很大一部分工作负载无法从多个线程中受益:每个线程都在计算每个阶乘。这意味着加速不能线性化-如阿姆达尔定律所述

那么,如何在不计算阶乘的情况下获得结果呢?一种方法是使用霍纳法。例如,考虑比较简单的系列sum(1/k!),它收敛到e但比您的慢一些。

假设您要计算的sum(1/k!)最大k =100。使用霍纳方法,您从头开始并提取公因子:

sum(1/k!, k=0..n) = 1/100! + 1/99! + 1/98! + ... + 1/1! + 1/0!
= ((... (((1/100 + 1)/99 + 1)/98 + ...)/2 + 1)/1 + 1

看看如何从1开始,除以100,然后加1,再除以99,再加上1,除以98,再加上1,依此类推?这使得程序非常简单:

private static BigDecimal serialHornerMethod() {
    BigDecimal accumulator = BigDecimal.ONE;
    for (int k = 10000; k > 0; k--) {
        BigDecimal divisor = new BigDecimal(k);
        accumulator = accumulator.divide(divisor, 10000, RoundingMode.HALF_EVEN)
                                 .add(BigDecimal.ONE);
    }
    return accumulator;
}

好的,这是一种串行方法,如何使用并行方法?这是两个线程的示例:首先将序列分为偶数和奇数项:

1/100! + 1/99! + 1/98! + 1/97! + ... + 1/1! + 1/0! =
(1/100! + 1/98! + ...  + 1/0!) + (1/99! + 1/97! + ... + 1/1!)

然后将Horner的方法应用于偶数和奇数项:

1/100! + 1/98! + 1/96! + ...  + 1/2! + 1/0! = 
((((1/(100*99) + 1)/(98*97) + 1)/(96*95) + ...)/(2*1) + 1

and:

1/99! + 1/97! + 1/95! + ... + 1/3! + 1/1! =
((((1/(99*98) + 1)/(97*96) + 1)/(95*94) + ...)/(3*2) + 1

这与串行方法一样容易实现,并且从1线程到2线程的线性加速非常接近:

    private static BigDecimal partialHornerMethod(int start) {
        BigDecimal accumulator = BigDecimal.ONE;
        for (int i = start; i > 0; i -= 2) {
            int f = i * (i + 1);
            BigDecimal divisor = new BigDecimal(f);
            accumulator = accumulator.divide(divisor, 10000, RoundingMode.HALF_EVEN)
                                     .add(BigDecimal.ONE);
        }
        return accumulator;
    }

// Usage:

ExecutorService executorService = Executors.newFixedThreadPool(2);
Future<BigDecimal> submit = executorService.submit(() -> partialHornerMethod(10000));
Future<BigDecimal> submit1 = executorService.submit(() -> partialHornerMethod(9999));
BigDecimal result = submit1.get().add(submit.get());

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章