メモ:java.util.concurrent.RecursiveTaskのFibonacci
Java7で導入されたFork/Joinフレームワークの中に、java.util.concurrent.RecursiveTask
というクラスがあります。これは、再帰を使用して計算を行った結果を生成するためのForkJoinTask
の実装です。
このクラスのAPIドキュメントのページを見ると、例としてフィボナッチ数列の計算が書かれています。
準備として、ログ出力メソッドを定義しておきます。
static void log(String fmt, Object... args) { String now = String.format("%1$tT.%1$tL", System.currentTimeMillis()); now = "--:" + now.substring(3); System.out.printf("%s [%s] %s%n", now, Thread.currentThread().getName(), String.format(fmt, args)); }
今回は使用していませんが、forkされたタスクをトレースするためにスレッド名を表示できるようになっています。(2015-03-06の追記のついでにsimplifyしました。)
それでは、APIドキュメントとほぼ同じバージョンを実行してみます。
- version 1: APIドキュメントとほぼ同じ
// import java.util.concurrent.*; // ForkJoinPool, RecursiveTask @SuppressWarnings("serial") class Fibonacci extends RecursiveTask<Integer> { final int n; Fibonacci(int n) { this.n = n; } @Override protected Integer compute() { // 2015-01-31 ちょっと修正 if (n <= 1) return n; Fibonacci f1 = new Fibonacci(n - 1); f1.fork(); return new Fibonacci(n - 2).compute() + f1.join(); } } ForkJoinPool pool = new ForkJoinPool(3); log("version 1 start"); for (int i = 0; i <= 40; i++) log("fib(%d) = %d", i, pool.invoke(new Fibonacci(i)));
n - 2
の計算は自身のスレッドが再帰的に処理し、n - 1
の計算は、forkしたスレッドで処理してそれを待ち合わせ、その合計を返します。
- 実行結果
--:32:43.958 [main] version 1 start --:32:44.045 [main] fib(0) = 0 --:32:44.045 [main] fib(1) = 1 --:32:44.046 [main] fib(2) = 1 --:32:44.046 [main] fib(3) = 2 --:32:44.047 [main] fib(4) = 3 --:32:44.047 [main] fib(5) = 5 --:32:44.048 [main] fib(6) = 8 --:32:44.048 [main] fib(7) = 13 --:32:44.049 [main] fib(8) = 21 --:32:44.049 [main] fib(9) = 34 --:32:44.050 [main] fib(10) = 55 --:32:44.050 [main] fib(11) = 89 --:32:44.051 [main] fib(12) = 144 (中略) --:32:52.724 [main] fib(37) = 24157817 --:32:57.977 [main] fib(38) = 39088169 --:33:06.858 [main] fib(39) = 63245986 --:33:21.292 [main] fib(40) = 102334155
fib(0)
~fib(40)
までを計算すると、38秒*1ほどかかりました。
毎回計算するので当然ですね。コンストラクターが呼ばれた回数は866,988,831回でした。
また、int
なのでfib(47)
でオーバーフローします。
これを、キャッシュとBigInteger
に対応してみましょう。
- version 2: version 1 の BigInteger対応 + キャッシュ版
// HashMap#computeIfAbsentを使用するのでJava8以降 // import static java.math.BigInteger.*; // ZERO, ONE // import java.math.BigInteger; // import java.util.HashMap; // import java.util.concurrent.*; // ForkJoinPool, RecursiveTask // static final BigInteger TWO = BigInteger.valueOf(2); HashMap<BigInteger, BigInteger> cache = new HashMap<>(); cache.put(ZERO, ZERO); cache.put(ONE, ONE); @SuppressWarnings("serial") class Fibonacci extends RecursiveTask<BigInteger> { final BigInteger n; Fibonacci(BigInteger n) { this.n = n; } @Override protected BigInteger compute() { assert n.min(ZERO).equals(ZERO); return cache.computeIfAbsent(n, x -> { Fibonacci f1 = new Fibonacci(n.subtract(ONE)); f1.fork(); return new Fibonacci(n.subtract(TWO)).compute().add(f1.join()); }); } } ForkJoinPool pool = new ForkJoinPool(3); log("version 2 start"); for (int i = 500; i <= 505; i++) log("fib(%d) = %d", i, pool.invoke(new Fibonacci(BigInteger.valueOf(i))));
HashMap#computeIfAbsent
を使って、値が存在しないときだけ計算を実行するようにしています。
並列処理だからといってConcurrentHashMap
を使ってしまうと、途中から計算した場合に枝が増え過ぎてデッドロックが発生してしまいます。よって、ここでは通常のHashMap
を使用します。
追記(2015-03-06)
@yohhoyさんからツイートいただきました。
上の例と同じように、FibonacciのキャッシュをConcurrentHashMap
を使って実装した場合の問題に関連して、興味深い情報が記載されています。ありがとうございました。
Recursive ConcurrentHashMap.computeIfAbsent() call never terminates. Bug or “feature”? http://t.co/lz4uYhhDla fixed?
— yoh (@yohhoy) 2015, 3月 6
余談ですが、もう一度確認してみようと思ってversion2を実行したら、stack overflowになったという冗談みたいな落ちが。ごめんなさい。
(追記ここまで)
それでは、一気にfib(500)
あたりを数件実行してみます。
- 実行結果
--:37:00.881 [main] version 2 start --:37:01.063 [main] fib(500) = 139423224561697880139724382870407283950070256587697307264108962948325571622863290691557658876222521294125 --:37:01.064 [main] fib(501) = 225591516161936330872512695036072072046011324913758190588638866418474627738686883405015987052796968498626 --:37:01.065 [main] fib(502) = 365014740723634211012237077906479355996081581501455497852747829366800199361550174096573645929019489792751 --:37:01.065 [main] fib(503) = 590606256885570541884749772942551428042092906415213688441386695785274827100237057501589632981816458291377 --:37:01.066 [main] fib(504) = 955620997609204752896986850849030784038174487916669186294134525152075026461787231598163278910835948084128 --:37:01.067 [main] fib(505) = 1546227254494775294781736623791582212080267394331882874735521220937349853562024289099752911892652406375505
1秒未満で完了しました!
(おわり)