Fork-Join模式

什么是Fork-Join

“分治”问题可以很容易地通过Callable线程的Executor接口来解决。通过为每个任务实例化一 个Callable实例,并在ExecutorService类中汇总计算结果来得出最终结果可以实现这一目的。那么自然而然想到的问题就是,如果这接口已经做得不错了,我们为什么还需要Java 7的其他框架?
使用ExecutorServiceCallable的主要问题是,Callable实例在本质上是阻塞的。一旦一个Callable实例开始执行,其他所有Callable都会被阻塞。由于队列后面的Callable实例在前一实例未执行完成的时候不会被执行,因此许多资源无法得到利用。Fork-Join框架被引入来解决这一并行问题,而Executor解决的是并发问题(译者注:并发和并行的区别就是一个处理器同时处理多个任务和多个处理器或者是多核的处理器同时处理多个不同的任务)。
Fork-Join模式,分而治之,然后合并结果,这么一种编程模式。(注:Fork-Join是一个单机框架,类似的分布式的框架有Hadoop这类的,它们的计算模型是MapReduce,体现了和Fork-Join一样的思想-分而治之。)
Fork-Join框架是一个”多核友好的、轻量级并行框架”,它支持并行编程风格,将问题递归拆分成多个更小片断,以并行和调配的方式解决。Fork-join融合了分而治之技术;获取问题后,递归地将它分成多个子问题,直到每个子问题都足够小,以至于可以高效地串行地解决它们。递归的过程将会把问题分成两个或者多个子问题,然后把这些问题放入队列中等待处理(fork步骤),接下来等待所有子问题的结果(join步骤),把多个结果合并到一起。
Fork-Join模式有自己的适用范围。如果一个应用能被分解成多个子任务,并且组合多个子任务的结果就能够获得最终的答案,那么这个应用就适合用Fork-Join模式来解决。
一个Fork-Join模式的示意图,位于图上部的Task依赖于位于其下的Task的执行,只有当所有的子任务都完成之后,调用者才能获得Task 0的返回结果。如下图。


Fork-Join模式能够解决很多种类的并行问题。通过使用Doug Lea提供的Fork-Join框架,软件开发人员只需要关注任务的划分和中间结果的组合就能充分利用并行平台的优良性能。其他和并行相关的诸多难于处理的问题,例如负载平衡、同步等,都可以由框架采用统一的方式解决。这样,我们就能够轻松地获得并行的好处而避免了并行编程的困难且容易出错的缺点。

work-stealing算法

工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。工作窃取的运行流程图如下。

那么为什么需要使用工作窃取算法呢?假如我们需要做一个比较大的任务,我们可以把这个任务分割为若干互不依赖的子任务,为了减少线程间的竞争,于是把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应,比如A线程负责处理A队列里的任务。但是有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。
工作窃取算法的优点是充分利用线程进行并行计算,并减少了线程间的竞争,其缺点是在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且消耗了更多的系统资源,比如创建多个线程和多个双端队列。
Fork-join只有你在将一个任务拆分成小任务时才有用处。Fork-Join池是是一个work-stealing工作窃取线程池。每个工作线程维护本地任务队列。
Fork-join池里的线程不是在等待新任务,而是主动分裂的现有任务到更小的,并帮助完成其他线程的大任务(切分以后)。如图所示。

work-stealing所采用的基本调度策略。

  1. 每一个工作线程维护自己的调度队列中的可运行任务。
  2. 队列以双端队列的形式被维护(注:deques通常读作”decks”),不仅支持后进先出——LIFO的pushpop操作,还支持先进先出——FIFO的take操作。
  3. 对于一个给定的工作线程来说,任务所产生的子任务将会被放入到工作者自己的双端队列中。
  4. 工作线程使用后进先出——LIFO的顺序,通过弹出任务来处理队列中的任务。
  5. 当一个工作线程的本地没有任务去运行的时候,它将使用先进先出——FIFO的规则尝试随机的从别的工作线程中拿(“偷窃”)一个任务去运行。
  6. 当一个工作线程触及了join操作,如果可能的话它将处理其他任务,直到目标任务被告知已经结束(通过isDone())。所有的任务都会无阻塞的完成。
  7. 当一个工作线程无法再从其他线程中获取任务和失败处理的时候,它就会退出(通过yields, sleeps, 和/或者优先级调整)并经过一段时间之后再度尝试直到所有的工作线程都被告知他们都处于空闲的状态。在这种情况下,他们都会阻塞直到其他的任务再度被上层调用。
  8. 使用后进先出——LIFO用来处理每个工作线程的自己任务,但是使用先进先出——FIFO规则用于获取别的任务,这是一种被广泛使用的进行递归Fork-Join设计的一种调优手段。

让偷取任务的线程从队列拥有者相反的方向进行操作会减少线程竞争。同样体现了递归分治算法的大任务优先策略。因此,更早期被偷取的任务有可能会提供一个更大的单元任务,从而使得偷取线程能够在将来进行递归分解。

Fork-Join框架

我们已经很清楚Fork-Join框架的需求了,那么我们可以思考一下,如果让我们来设计一个Fork-Join框架,该如何设计?这个思考有助于你理解Fork-Join框架的设计。

  1. 分割任务。首先我们需要有一个fork类来把大任务分割成子任务,有可能子任务还是很大,所以还需要不停的分割,直到分割出的子任务足够小。
  2. 执行任务并合并结果。分割的子任务分别放在双端队列(线程,队列一一对应)里,然后几个启动线程分别从双端队列里获取任务执行。子任务执行完的结果都统一放在一个队列里,启动一个线程从队列里拿数据,然后合并这些数据。


Fork-Join使用两个类来完成以上两件事情。

ForkJoinTask

我们要使用Fork-Join框架,必须首先创建一个Fork-Join任务。它提供在任务中执行fork()join()操作的机制,通常情况下我们不需要直接继承ForkJoinTask类,而只需要继承它的子类,Fork-Join框架提供了以下两个子类。

  • RecursiveAction:用于没有返回结果的任务。
  • RecursiveTask:用于有返回结果的任务。

ForkJoinTask有两个主要的方法。

fork()

这个方法决定了ForkJoinTask的异步执行,凭借这个方法可以创建新的任务。

join()

该方法负责在计算完成侯返回结果,因此允许一个任务等待另一任务执行完成。

ForkJoinPool

ForkJoinTask需要通过ForkJoinPool来执行,任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务(work-stealing)。ForkJoinPool会尝试在任何时候都维持与可用的处理器数目一样数目的活动线程数。

注意

可用线程数和硬件支持。线程这东西,也是有开销的东西,绝对不是越多越好,尤其在硬件基础有限的情况下。
任务分解的粒度。和前者有关系,就是分解的任务,“小”到什么程度是可以接受的,不可再分。切分到多少才合适呢?一般切分到一个阈值,再切分下去就没有意义的了。

Fork-Join框架使用

extends RecursiveTask

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package com.forkjoin2;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;

public class CountTask extends RecursiveTask<Object> {
private static final long serialVersionUID = 1L;

private static final int THRESHOLD = 2;

private int start;
private int end;

public CountTask(int start, int end) {
this.start = start;
this.end = end;
}

@Override
protected Integer compute() {
int sum = 0;
boolean canCompute = (end - start) <= THRESHOLD;
if (canCompute) {
for (int i = start; i <= end; i++) {
sum += i;
}
} else {
// 如果任务大于阀值,就分裂成两个子任务计算
int mid = (start + end) / 2;
CountTask leftTask = new CountTask(start, mid);
CountTask rightTask = new CountTask(mid + 1, end);

//异步的执行子任务
leftTask.fork();
rightTask.fork();

// 等待子任务执行完,并得到结果
int rightResult = (int) rightTask.join();
int leftResult = (int) leftTask.join();
sum = leftResult + rightResult;
}
return sum;
}

public static void main(String[] args) {
ForkJoinPool forkJoinPool = new ForkJoinPool();
// 生成一个计算资格,负责计算1+2+3+4
CountTask task = new CountTask(1, 4);
@SuppressWarnings({ "rawtypes" })
Future result = forkJoinPool.submit(task);
try {
System.out.println(result.get());
} catch (Exception e) {
e.printStackTrace();
}
}
}

在Fork-Join框架中,提交任务的时候,有同步和异步两种方式。以前使用的invokeAll()是同步的,也就是任务提交后,这个方法不会返回直到所有的任务都处理完了。
而还有另一种方式,就是使用fork(),这个是异步的。也就是你提交任务后,fork()立即返回,可以继续下面的任务。这个线程也会继续运行。
下面我们以一个查询磁盘的以log结尾的文件的程序例子来说明异步的用法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package com.forkjoin5;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

/*
*
*/
public class FolderProcessor extends RecursiveTask<List<String>> {
private static final long serialVersionUID = 1L;

private String path;// 路径
private String extension;// 文件后缀名

public FolderProcessor(String path, String extension) {
super();
this.path = path;
this.extension = extension;
}

// 任务执行
@Override
protected List<String> compute() {
// 符合文件搜索条件的list,添加文件的名字
List<String> list = new ArrayList<String>();
// 任务列表
List<FolderProcessor> taskList = new ArrayList<FolderProcessor>();
File file = new File(path);
// 子文件列表
File[] content = file.listFiles();
if (content != null) {
for (int i = 0; i < content.length; i++) {
//如果是目录的情况,继续搜索子文件
if (content[i].isDirectory()) {
//子文件任务
FolderProcessor task = new FolderProcessor(content[i].getAbsolutePath(),extension);
// 异步方式提交任务
task.fork();
taskList.add(task);
} else {
//非目录
if (checkFile(content[i].getName())) {
list.add(content[i].getAbsolutePath());
}
}
}
}

if (taskList.size() > 50) {
System.out.printf("%s: %d tasks ran.\n", file.getAbsolutePath(), taskList.size());
}

addResultsFromTasks(list, taskList);
return list;
}

//添加文件的名字
private void addResultsFromTasks(List<String> list, List<FolderProcessor> taskList) {
for (FolderProcessor item : taskList) {
list.addAll(item.join());
}
}

//检测文件
private boolean checkFile(String name) {
return name.endsWith(extension);
}

// 实现 showLog() 方法。它接收 ForkJoinPool 对象作为参数和写关于线程和任务的执行的状态的信息。
private static void showLog(ForkJoinPool pool) {
System.out.printf("**********************\n");
System.out.printf("Main: Fork/Join Pool log\n");
//此方法返回池的并行的级别。
System.out.printf("Main: Fork/Join Pool: Parallelism:%d\n", pool.getParallelism());
//此方法返回 int 值,它是ForkJoinPool内部线程池的worker线程们的数量。
System.out.printf("Main: Fork/Join Pool: Pool Size:%d\n", pool.getPoolSize());
//此方法返回当前执行任务的线程的数量。
System.out.printf("Main: Fork/Join Pool: Active Thread Count:%d\n", pool.getActiveThreadCount());
//此方法返回没有被任何同步机制阻塞的正在工作的线程。
System.out.printf("Main: Fork/Join Pool: Running Thread Count:%d\n", pool.getRunningThreadCount());
//此方法返回已经提交给池还没有开始他们的执行的任务数。
System.out.printf("Main: Fork/Join Pool: Queued Submission:%d\n", pool.getQueuedSubmissionCount());
//此方法返回已经提交给池已经开始他们的执行的任务数。
System.out.printf("Main: Fork/Join Pool: Queued Tasks:%d\n", pool.getQueuedTaskCount());
//此方法返回 Boolean 值,表明这个池是否有queued任务还没有开始他们的执行。
System.out.printf("Main: Fork/Join Pool: Queued Submissions:%s\n", pool.hasQueuedSubmissions());
//此方法返回 long 值,worker 线程已经从另一个线程偷取到的任务数。
System.out.printf("Main: Fork/Join Pool: Steal Count:%d\n", pool.getStealCount());
//此方法返回 Boolean 值,表明 fork/join 池是否已经完成执行。
System.out.printf("Main: Fork/Join Pool: Terminated :%s\n", pool.isTerminated());
System.out.printf("**********************\n");
}

public static void main(String[] args) throws InterruptedException {
ForkJoinPool pool = new ForkJoinPool();
FolderProcessor system = new FolderProcessor("F:\\Workspace", "java");
FolderProcessor apps = new FolderProcessor("F:\\Workspace", "jsp");

pool.execute(system);
pool.execute(apps);
while (!apps.isDone() || !apps.isDone()) {
showLog(pool);
TimeUnit.SECONDS.sleep(5000);
}
pool.shutdown();
List<String> results = system.join();
System.out.printf("System: %d files found.\n", results.size());
results = apps.join();
System.out.printf("Apps: %d files found.\n", results.size());
}
}

例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
package com.forkjoin7;

import java.util.concurrent.RecursiveTask;


//递归的例子
class Fibonacci extends RecursiveTask<Integer> {
private static final long serialVersionUID = 1L;
final int n;

Fibonacci(int n) {
this.n = n;
}

private int compute(int small) {
final int[] results = { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 };
return results[small];
}

public Integer compute() {
if (n <= 10) {
return compute(n);
}
Fibonacci f1 = new Fibonacci(n - 1);
Fibonacci f2 = new Fibonacci(n - 2);
f1.fork(); //子任务异步执行
f2.fork();
//join : 阻塞等待结果完成
return f1.join() + f2.join();
}
}
1
2
3
4
5
6
public static void main(String[] args) throws InterruptedException, ExecutionException {
ForkJoinTask<Integer> fjt = new Fibonacci(45);
ForkJoinPool fjpool = new ForkJoinPool();
Future<Integer> result = fjpool.submit(fjt);
System.out.println("Foke/Join = " + result.get());
}

自定义forkjoin任务类

MyWorkerTask

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
package com.forkjoin8;

import java.util.Date;
import java.util.concurrent.ForkJoinTask;

public abstract class MyWorkerTask extends ForkJoinTask<Void> {
private static final long serialVersionUID = -1153949034138340822L;

private String name;

public MyWorkerTask() {
}

public MyWorkerTask(String name) {
this.name = name;
}

// 1
@Override
public Void getRawResult() {
return null;
}

// 2
@Override
protected void setRawResult(Void value) {

}

// 3
@Override
protected boolean exec() {
Date startDate = new Date();
compute();
Date finishDate = new Date();
long diff = finishDate.getTime() - startDate.getTime();
System.out.printf("MyWorkerTask: %s : %d Milliseconds to complete.\n", name, diff);
return true;
}

// 4
public String getName() {
return name;
}

// 5
protected abstract void compute();
}

Task

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
package com.forkjoin8;

public class Task extends MyWorkerTask {
private static final long serialVersionUID = -1773159586852826490L;

private int array[];
private int start;
private int end;

public Task(String name, int array[], int start, int end) {
super(name);
this.array = array;
this.start = start;
this.end = end;
}

// 6
protected void compute() {
if (end - start > 100) {
int mid = (end + start) / 2;
Task task1 = new Task(this.getName() + "1", array, start, mid);
Task task2 = new Task(this.getName() + "2", array, mid, end);
invokeAll(task1, task2);
} else {//7
for (int i = start; i < end; i++) {
array[i]++;
}
//最后,让正在执行任务的线程进入休眠50毫秒。
try {
Thread.sleep(50);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}

MainTest2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
package com.forkjoin8;

import java.util.concurrent.ForkJoinPool;

//如何为 Fork/Join 框架实现你自己的任务,实现一个任务扩展ForkJoinTask类。
//你将要实现的任务是计量运行时间并写入操控台,这样你可以控制它的进展(evolution)
public class MainTest2 {
public static void main(String[] args) throws Exception {
int array[] = new int[10000];
ForkJoinPool pool = new ForkJoinPool();
Task task = new Task("Task", array, 0, array.length);
pool.invoke(task);
pool.shutdown();
System.out.printf("Main: End of the program.\n");
}
}

  1. 实现getRawResult()。这是ForkJoinTask类的抽象方法之一。由于任务不会返回任何结果,此方法返回的一定是null值。
  2. 实现setRawResult()。这是ForkJoinTask类的另一个抽象方法。由于任务不会返回任何结果,方法留白即可。
  3. 实现exec()抽象方法。这是任务的主要方法。在这个例子,把任务的算法委托给compute()。计算方法的运行时间并写入操控台。
  4. 实现getName()来返回任务的名字。
  5. 声明抽象方法compute()。像我们之前提到的,此方法实现任务的算法,必须是由MyWorkerTask类的子类实现。
  6. 实现compute()。此方法通过start和end。属性来决定增加array的元素块。如果元素块的元素超过100个,把它分成2部分,并创建2个Task对象来处理各个部分。再使用invokeAll()把这些任务发送给池。
  7. 如果元素块的元素少于100,使用for循环增加全部的元素。

创建线程池

  1. 创建MyWorkerThreadFactory实现ForkJoinWorkerThreadFactory工厂类。
  2. 创建自定义工作者线程MyWorkerThread,继承ForkJoinWorkerThread
  3. 创建工作者任务类MyRecursiveTask

MyWorkerThreadFactory

1
2
3
4
5
6
7
8
9
10
11
12
13
package com.forkjoin9;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory;

public class MyWorkerThreadFactory implements ForkJoinWorkerThreadFactory{

@Override
public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
return new MyWorkerThread(pool);
}
}

MyWorkerThread

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
package com.forkjoin9;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;

public class MyWorkerThread extends ForkJoinWorkerThread {

// 每个线程都有自己的任务计数器
private static ThreadLocal<Integer> taskCounter = new ThreadLocal<Integer>();

protected MyWorkerThread(ForkJoinPool pool) {
super(pool);
}

@Override
protected void onStart() {
super.onStart();
System.out.printf("MyWorkerThread %d: Initializing taskcounter.\n", getId());
taskCounter.set(0);
}

// 中断
@Override
protected void onTermination(Throwable exception) {
System.out.printf("MyWorkerThread %d:%d\n", getId(), taskCounter.get());
super.onTermination(exception);
}

public void addTask() {
int counter = taskCounter.get().intValue();
counter++;
taskCounter.set(counter);
}

}

MyRecursiveTask

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
package com.forkjoin9;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

public class MyRecursiveTask extends RecursiveTask<Integer> {
private static final long serialVersionUID = -6615653526171656238L;

private int array[];
private int start, end;

public MyRecursiveTask(int array[], int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}

@Override
protected Integer compute() {
Integer ret = 0;
MyWorkerThread thread = (MyWorkerThread) Thread.currentThread();
thread.addTask();
if (end - start > 100) {
int mid = (start + end) / 2;
MyRecursiveTask task1 = new MyRecursiveTask(array, start, mid);
MyRecursiveTask task2 = new MyRecursiveTask(array, mid, end);
task1.fork();
task2.fork();
int a1 = task1.join();
int a2 = task2.join();
ret = a1 + a2;
} else {
for (int i = start; i < end; i++) {
ret += array[i];
}
}
return ret;
}
}

MainTest

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
//如何实现一个在ForkJoinPool类中使用的自定义的工作者线程,及如何使用一个工厂来使用它。
public class MainTest {
public static void main(String[] args) throws Exception {
MyWorkerThreadFactory factory = new MyWorkerThreadFactory();
ForkJoinPool pool = new ForkJoinPool(4, factory, null, false);
int array[] = new int[100000];
for (int i = 0; i < array.length; i++) {
array[i] = 1;
}
MyRecursiveTask task = new MyRecursiveTask(array, 0, array.length);
pool.execute(task);
task.join();
pool.shutdown();
pool.awaitTermination(1, TimeUnit.DAYS);
System.out.printf("Main: Result: %d\n", task.get());
System.out.printf("Main: End of the program\n");
}
}

Fork-Join框架的异常处理

ForkJoinTask在执行的时候可能会抛出异常,但是我们没办法在主线程里直接捕获异常,所以ForkJoinTask提供了isCompletedAbnormally()来检查任务是否已经抛出异常或已经被取消了,并且可以通过ForkJoinTaskgetException()获取异常。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
package com.forkjoin2;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;

public class TaskForException extends RecursiveTask {
private static final long serialVersionUID = 1L;

@Override
protected Integer compute() {
try {
System.out.println(1 / 0);
} catch (Exception e) {
System.out.println("异常:");
//e.printStackTrace();
}
return 0;
}

@SuppressWarnings("unchecked")
public static void main(String[] args) {
ForkJoinPool forkJoinPool = new ForkJoinPool();
TaskForException task = new TaskForException();
Future result = forkJoinPool.submit(task);
try {
System.out.println(result.get());
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
boolean f=task.isCompletedAbnormally();
//处理过的异常,不会报错,同样不会进入if分支。。。
if (task.isCompletedAbnormally()) {
System.out.println("进入if");
System.out.println(task.getException());
}
}
}

Fork-Join框架的性能测试

例1

测试Fork-Join框架性能,使用线程池,for循环与Fork-Join框架作比较。
Task

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package com.fokejoin1.model;

import java.util.List;

/**
* 任务类
*/
public class Task {
/**
* 操作列表
*/
private final List<Operation> operations;
/**
* Type of task 任务类型
*/
private final TaskType taskType;

public Task(List<Operation> operations, TaskType taskType) {
this.operations = operations;
this.taskType = taskType;
}

@Override
public String toString() {
return taskType.name() + " (" + operations.size() + ")";
}

public List<Operation> getOperations() {
return operations;
}

public TaskType getTaskType() {
return taskType;
}

/**
* 任务的类型
* 任务数目大小
*/
public static enum TaskType {
XS(10), S(100), M(1000), L(10000), XL(100000), XXL(1000000);
private final int range;

TaskType(int range) {
this.range = range;
}

public int getRange() {
return range;
}
}
}

Operation

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
package com.fokejoin1.model;

/**
* Models an operation 操作模式
*
*/
public abstract class Operation {
// 操作类型
public enum OperationType {
Q, // Query
U // Update
}

private OperationType operationType;

public Operation(OperationType intervalType) {
this.operationType = intervalType;
}

public OperationType getIntervalType() {
return operationType;
}

/**
* Query an interval 查询操作
*/
public static class QueryIntervalOperation extends Operation {
private int left;
private int right;

public QueryIntervalOperation(int left, int right) {
super(OperationType.Q);
this.left = left;
this.right = right;
}

@Override
public String toString() {
return "" + super.operationType.name() + "[" + left + ", " + right + "]";
}

public int getLeft() {
return left;
}

public int getRight() {
return right;
}
}

/**
* Update at an index 更新操作
*/
public static class UpdateIntervalOperation extends Operation {
private int index;
private int val;

public UpdateIntervalOperation(int index, int val) {
super(OperationType.U);
this.index = index;
this.val = val;
}

public int getVal() {
return val;
}

@Override
public String toString() {
return "" + super.operationType.name() + "[@" + this.index + ", " + this.val + "]";
}

public int getIndex() {
return index;
}
}
}

TaskProducer创建任务信息的类,创建测试任务的数目,基本信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package com.fokejoin1.producer;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Scanner;

import com.fokejoin1.model.Operation;
import com.fokejoin1.model.Task;

/**
* 创建任务(任务生产者):顺序数字的任务,随机数字的任务
*/
public class TaskProducer {

/**
* 随机任务数=100
*/
private static final int RAND_SEED = 100;
private static Random randomGenerator = new Random(RAND_SEED);
/**
* 任务的文件
*/
private static final String TASK_FILENAME = "task.txt";
/**
* 从左至右边缘的间隔最大跨度
*/
private static final int MAX_INTERVAL_LENGTH = 10000;
private static final int MAX_INTERVAL_INDEX_VALUE = 10000000;
private final Scanner scanner;

/**
* 构造函数:读取任务文件
*
* @throws FileNotFoundException
*/
public TaskProducer() throws FileNotFoundException {
scanner = new Scanner(new File(TASK_FILENAME));
scanner.useDelimiter(",");
}

/**
* 从任务文件中获取任务
*/
public Task getNext() {
if (scanner == null) {
return null;
}
if (scanner.hasNext()) {
// 任务
String taskAsString = scanner.next();
return produceQueryTask(Task.TaskType.valueOf(taskAsString));
}
return null;
}

/**
* 创建查询任务
*
* @param taskType
* 任务类型
* @return 任务对象
*/
private Task produceQueryTask(Task.TaskType taskType) {
// 操作列表
List<Operation> operations = new ArrayList<Operation>();
for (int i = 0; i < taskType.getRange(); ++i) {
operations.add(getRandomQueryInterval());
}
// 构造任务
return new Task(operations, taskType);
}

/**
* 随机创建查询左右对象
*/
private Operation.QueryIntervalOperation getRandomQueryInterval() {
int left = randomGenerator.nextInt(MAX_INTERVAL_INDEX_VALUE - MAX_INTERVAL_LENGTH);
int right = left + randomGenerator.nextInt(MAX_INTERVAL_LENGTH);
return new Operation.QueryIntervalOperation(left, right);
}
}

TaskResultHandler绑定任务执行结果类,countDownLatch计数任务。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
package com.fokejoin1.result;

import java.util.concurrent.CountDownLatch;

import com.fokejoin1.model.Task;

/**
* 任务结果绑定对象
*/
public class TaskResultHandler {

private CountDownLatch allTasksDoneLatch;

public TaskResultHandler(CountDownLatch allTasksDoneLatch) {
this.allTasksDoneLatch = allTasksDoneLatch;
}

public void reportQueryResult(Task task, int index, int val) {
}

public void reportUpdateResult(Task task, int index) {
}

/**
* 任务完成,计数-1
*/
public void taskDone(Task task) {
allTasksDoneLatch.countDown();
}
}

TaskSolver

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package com.fokejoin1.solver;

import com.fokejoin1.model.Operation;
import com.fokejoin1.model.Task;
import com.fokejoin1.result.TaskResultHandler;
import com.fokejoin1.rmq.RMQSegmentTree;

/**
* 任务解决类
*/
public class TaskSolver {

/**
* 用户查询,更新的对象
*/
private RMQSegmentTree segmentTree;
/**
* 任务结果绑定对象
*/
private TaskResultHandler taskResultHandler;

/**
* 构造函数,初始化对象
*
* @param segmentTree
* @param taskResultHandler
*/
public TaskSolver(RMQSegmentTree segmentTree, TaskResultHandler taskResultHandler) {
this.segmentTree = segmentTree;
this.taskResultHandler = taskResultHandler;
}

/**
* 解决任务
*/
public void solve(Task t) {
solve(t, 0, t.getOperations().size());
}

/**
*
* @param t
* 任务对象
* @param from
* 开始位置
* @param to
* 结束位置
*/
public void solve(Task t, int from, int to) {
for (int k = from; k < to; ++k) {
solve(t, k);
}
}

/**
* 任务的解决
*
* @param t
* @param k
*/
private void solve(Task t, int k) {
// 获取操作对象
Operation operation = t.getOperations().get(k);
switch (operation.getIntervalType()) {
case Q:
solveQuery(t, k, operation);
case U:
solveUpdate(t, k, operation);
}
}

/**
* 查询任务解决方法
*
* @param t
* 任务
* @param k
* 操作索引
* @param operation
* 查询操作
*/
private void solveQuery(Task t, int k, Operation operation) {
// 操作类的内部类:查询操作
Operation.QueryIntervalOperation intervalQ = (Operation.QueryIntervalOperation) operation;
int result = segmentTree.query(intervalQ.getLeft(), intervalQ.getRight());
// 任务结束,结果绑定
taskResultHandler.reportQueryResult(t, k, result);
}

/**
* 更新操作
*
* @param t
* @param k
* @param operation
*/
private void solveUpdate(Task t, int k, Operation operation) {
}

/**
* @return The task result handler
*/
public TaskResultHandler getTaskResultHandler() {
return taskResultHandler;
}

}

RMQSegmentTree

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package com.fokejoin1.rmq;

/**
* 线段树
*
* Init : O(N) Query : O(log N) Update : O(log N)
*/
public class RMQSegmentTree {

/**
* tree[k] Holds the index of the smallest element from values[k_start] ..
* values[k_end] where [k_start, k_end] are the range values of node 'k'.
* Node 1 is defined as having [0 .. N] as range values Node 2 is defined as
* having [0 .. N/2] as range values Node 3 is defined as having [N/2+1 ..
* N] as range values ... Node k defined recursively using above logic
*
* tree存的是values的索引 tree的索引是node
*/
public int[] tree;
/**
* 查询或者更新的数组
*/
public int[] values;

/**
* 构造函数,初始化二叉树
*
* @param values
*/
public RMQSegmentTree(int[] values) {
int log2n = (int) (Math.log(values.length) / Math.log(2));
this.values = values;
// 在区间内的二叉树
tree = new int[1 << (log2n + 2)];
init(1, 0, values.length - 1);
}

/**
* 计算tree的节点值
*/
public void init(int node, int left, int right) {
// root节点
if (left == right) {
tree[node] = left;
} else {
int mid = (left + right) / 2;
init(2 * node, left, mid);// 左节点
init(2 * node + 1, mid + 1, right);// 右节点

int minIndexHalf1 = tree[2 * node];// 左节点
int minIndexHalf2 = tree[2 * node + 1];// 右节点
// 存最小值
tree[node] = (values[minIndexHalf1] <= values[minIndexHalf2]) ? minIndexHalf1
: minIndexHalf2;
}
}

/**
* 查询i,j之间的最小值
*
* @param node
* Id of the node
* @param left
* 左边的区间上限
* @param right
* 右边的区间下限
* @param i
* 左边的索引
* @param j
* 右边的索引
*
*/
private int query(int node, int left, int right, int i, int j) {
// 返回当前节点
if (i <= left && right <= j) {
return tree[node];
} else {
int mid = (left + right) / 2;
// 初始化左右索引的默认值,tree中的默认值
int minIndexHalf1 = -1, minIndexHalf2 = -1;
// 左边索引
if (i <= mid) {
minIndexHalf1 = query(2 * node, left, mid, i, j);
}
// 右边索引
if (j > mid) {
minIndexHalf2 = query(2 * node + 1, mid + 1, right, i, j);
}
// 返回右边的索引
if (minIndexHalf1 == -1) {
return minIndexHalf2;
}
// 返回左边的索引
if (minIndexHalf2 == -1) {
return minIndexHalf1;
}
// 返回最小的值
if (values[minIndexHalf1] <= values[minIndexHalf2]) {
return minIndexHalf1;
}
return minIndexHalf2;
}
}

/**
* 更新值
*
* @param node
* Id of the node
* @param left
* Current considered interval left index
* @param right
* Current considered interval right index
* @param i
* values的索引值
* @param val
* Value to which to update the index
*/
private void update(int node, int left, int right, int i, int val) {

if (left == right && left == i) {
tree[node] = i;
values[i] = val;
} else {
// 中间值
int mijl = (left + right) / 2;
// 更新左边
if (i <= mijl) {
update(2 * node, left, mijl, i, val);
} else {// 更新右边
update(2 * node + 1, mijl + 1, right, i, val);
}
// node的值是values中最小值
tree[node] = (values[tree[2 * node]] < values[tree[2 * node + 1]]) ? tree[2 * node]
: tree[2 * node + 1];
}
}

/**
* 指定索引范围内数组中的最小值的索引
*
* @param i
* @param j
* @return
*/
public int query(int i, int j) {
return query(1, 0, values.length - 1, i, j);
}

/**
* 更新值
*/
public void update(int i, int val) {
update(1, 0, values.length - 1, i, val);
}
}

AbstractTaskProcessor

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
package com.fokejoin1.processor;

import com.fokejoin1.model.Task;
import com.fokejoin1.solver.TaskSolver;

/**
* 任务进程抽象类
*
*/
public abstract class AbstractTaskProcessor {

protected TaskSolver taskSolver;

public AbstractTaskProcessor(TaskSolver taskSolver) {
this.taskSolver = taskSolver;
}

/**
* 任务执行
*
* @param task
*/
public abstract void process(Task task);

/**
* 任务的关闭
*/
public void shutdown() {
}
}

TaskProcessorFJ

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package com.fokejoin1.processor;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

import com.fokejoin1.model.Task;
import com.fokejoin1.solver.TaskSolver;

/**
* This is a Fork-Join processor for tasks, using a ForkJoinPool
*/
public class TaskProcessorFJ extends AbstractTaskProcessor {

private ForkJoinPool forkJoinPool;

public TaskProcessorFJ(TaskSolver taskSolver) {
super(taskSolver);
forkJoinPool = new ForkJoinPool();
}

@Override
public void process(Task task) {
forkJoinPool.invoke(new Subtask(task, 0, task.getOperations().size(), true));
}

/**
* 包装任务类
*
*/
private class Subtask extends RecursiveAction {
private static final long serialVersionUID = 1L;

/**
* 需要解决的任务
*/
final Task task;

final int from;
final int to;

/**
* Is this subtask == the initial task
*/
final boolean rootTask;

public Subtask(Task task, int from, int to, boolean rootTask) {
this.task = task;
this.from = from;
this.to = to;
this.rootTask = rootTask;
}

@Override
protected void compute() {
// 如果是XS, S, M,则解决它
if (to - from < Task.TaskType.L.getRange()) {
taskSolver.solve(task, from, to);
} else {// 如果是L, XL, XXL,拆分任务
int mid = (from + to) / 2;
invokeAll(new Subtask(this.task, from, mid, false), new Subtask(this.task, mid + 1,
to, false));
}
//如果是XS,S,M者进入if,计数器-1
//L,XL,XXL,分支任务计数器不执行-1
//总任务计数器才-1
if (rootTask) {
taskSolver.getTaskResultHandler().taskDone(task);
}
}
}
}

TaskProcessorPool

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
package com.fokejoin1.processor;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import com.fokejoin1.model.Task;
import com.fokejoin1.solver.TaskSolver;

/**
* This is a thread pool processor for tasks, using a fixed thread pool ExecutorService
*
* We spawn a pool of threads, and as soon as we have a job we submit it to the executor to process it
* This should run 4 threads solving tasks in parallel. This is not as efficient as a workstealing thread pool,
* because we might have one thread busy with a large task, while the others have nothing to do
*
*/
public class TaskProcessorPool extends AbstractTaskProcessor {
public final int POOL_SIZE = 4;

private ExecutorService threadPool;

public TaskProcessorPool(TaskSolver taskSolver) {
super(taskSolver);
threadPool = Executors.newFixedThreadPool(POOL_SIZE);
}

@Override
public void process(Task task) {
threadPool.submit(new TaskRunnable(task));
}

@Override
public void shutdown() {
threadPool.shutdown();
}

class TaskRunnable implements Runnable {
private Task task;

public TaskRunnable(Task task) {
this.task = task;
}

@Override
public void run() {
taskSolver.solve(task);
taskSolver.getTaskResultHandler().taskDone(task);
}
}

}

TaskProcessorSimple

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
package com.fokejoin1.processor;

import com.fokejoin1.model.Task;
import com.fokejoin1.solver.TaskSolver;

/**
* Naive implementation of a processor.
* We just solve the task in the same thread, sequentially
*/
public class TaskProcessorSimple extends AbstractTaskProcessor {

public TaskProcessorSimple(TaskSolver taskSolver) {
super(taskSolver);
}

@Override
public void process(Task task) {
taskSolver.solve(task);
taskSolver.getTaskResultHandler().taskDone(task);
}

}

评论

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×