argius note

プログラミング関連

メモ:java.util.concurrent.RecursiveTaskのFibonacci

Java7で導入されたFork/Joinフレームワークの中に、java.util.concurrent.RecursiveTaskというクラスがあります。これは、再帰を使用して計算を行った結果を生成するためのForkJoinTaskの実装です。
このクラスのAPIドキュメントのページを見ると、例としてフィボナッチ数列の計算が書かれています。


準備として、ログ出力メソッドを定義しておきます。

static void log(String fmt, Object... args) {
    String now = String.format("%1$tT.%1$tL", System.currentTimeMillis());
    now = "--:" + now.substring(3);
    System.out.printf("%s [%s] %s%n", now, Thread.currentThread().getName(), String.format(fmt, args));
}

今回は使用していませんが、forkされたタスクをトレースするためにスレッド名を表示できるようになっています。(2015-03-06の追記のついでにsimplifyしました。)



それでは、APIドキュメントとほぼ同じバージョンを実行してみます。

  • version 1: APIドキュメントとほぼ同じ
// import java.util.concurrent.*; // ForkJoinPool, RecursiveTask

@SuppressWarnings("serial")
class Fibonacci extends RecursiveTask<Integer> {
    final int n;
    Fibonacci(int n) {
        this.n = n;
    }
    @Override
    protected Integer compute() { // 2015-01-31 ちょっと修正
        if (n <= 1)
            return n;
        Fibonacci f1 = new Fibonacci(n - 1);
        f1.fork();
        return new Fibonacci(n - 2).compute() + f1.join();
    }
}
ForkJoinPool pool = new ForkJoinPool(3);
log("version 1 start");
for (int i = 0; i <= 40; i++)
    log("fib(%d) = %d", i, pool.invoke(new Fibonacci(i)));


n - 2の計算は自身のスレッドが再帰的に処理し、n - 1の計算は、forkしたスレッドで処理してそれを待ち合わせ、その合計を返します。

  • 実行結果
--:32:43.958 [main] version 1 start
--:32:44.045 [main] fib(0) = 0
--:32:44.045 [main] fib(1) = 1
--:32:44.046 [main] fib(2) = 1
--:32:44.046 [main] fib(3) = 2
--:32:44.047 [main] fib(4) = 3
--:32:44.047 [main] fib(5) = 5
--:32:44.048 [main] fib(6) = 8
--:32:44.048 [main] fib(7) = 13
--:32:44.049 [main] fib(8) = 21
--:32:44.049 [main] fib(9) = 34
--:32:44.050 [main] fib(10) = 55
--:32:44.050 [main] fib(11) = 89
--:32:44.051 [main] fib(12) = 144

(中略)

--:32:52.724 [main] fib(37) = 24157817
--:32:57.977 [main] fib(38) = 39088169
--:33:06.858 [main] fib(39) = 63245986
--:33:21.292 [main] fib(40) = 102334155

fib(0)fib(40)までを計算すると、38秒*1ほどかかりました。
毎回計算するので当然ですね。コンストラクターが呼ばれた回数は866,988,831回でした。

また、intなのでfib(47) でオーバーフローします。



これを、キャッシュとBigIntegerに対応してみましょう。

  • version 2: version 1 の BigInteger対応 + キャッシュ版
// HashMap#computeIfAbsentを使用するのでJava8以降

// import static java.math.BigInteger.*; // ZERO, ONE
// import java.math.BigInteger;
// import java.util.HashMap;
// import java.util.concurrent.*; // ForkJoinPool, RecursiveTask

// static final BigInteger TWO = BigInteger.valueOf(2);

HashMap<BigInteger, BigInteger> cache = new HashMap<>();
cache.put(ZERO, ZERO);
cache.put(ONE, ONE);
@SuppressWarnings("serial")
class Fibonacci extends RecursiveTask<BigInteger> {
    final BigInteger n;
    Fibonacci(BigInteger n) {
        this.n = n;
    }
    @Override
    protected BigInteger compute() {
        assert n.min(ZERO).equals(ZERO);
        return cache.computeIfAbsent(n, x -> {
            Fibonacci f1 = new Fibonacci(n.subtract(ONE));
            f1.fork();
            return new Fibonacci(n.subtract(TWO)).compute().add(f1.join());
        });
    }
}
ForkJoinPool pool = new ForkJoinPool(3);
log("version 2 start");
for (int i = 500; i <= 505; i++)
    log("fib(%d) = %d", i, pool.invoke(new Fibonacci(BigInteger.valueOf(i))));

HashMap#computeIfAbsentを使って、値が存在しないときだけ計算を実行するようにしています。
並列処理だからといってConcurrentHashMapを使ってしまうと、途中から計算した場合に枝が増え過ぎてデッドロックが発生してしまいます。よって、ここでは通常のHashMapを使用します。


追記(2015-03-06)

@yohhoyさんからツイートいただきました。
上の例と同じように、FibonacciのキャッシュをConcurrentHashMapを使って実装した場合の問題に関連して、興味深い情報が記載されています。ありがとうございました。


余談ですが、もう一度確認してみようと思ってversion2を実行したら、stack overflowになったという冗談みたいな落ちが。ごめんなさい。

(追記ここまで)


それでは、一気にfib(500)あたりを数件実行してみます。

  • 実行結果
--:37:00.881 [main] version 2 start
--:37:01.063 [main] fib(500) = 139423224561697880139724382870407283950070256587697307264108962948325571622863290691557658876222521294125
--:37:01.064 [main] fib(501) = 225591516161936330872512695036072072046011324913758190588638866418474627738686883405015987052796968498626
--:37:01.065 [main] fib(502) = 365014740723634211012237077906479355996081581501455497852747829366800199361550174096573645929019489792751
--:37:01.065 [main] fib(503) = 590606256885570541884749772942551428042092906415213688441386695785274827100237057501589632981816458291377
--:37:01.066 [main] fib(504) = 955620997609204752896986850849030784038174487916669186294134525152075026461787231598163278910835948084128
--:37:01.067 [main] fib(505) = 1546227254494775294781736623791582212080267394331882874735521220937349853562024289099752911892652406375505

1秒未満で完了しました!



(おわり)

*1:Windows7 32bit, CPU=Core2Duo 3.16GHzで実行。