AtCoder ABC293 E - Geometric Progression

※ a ≦ i ≦ bにおけるA^iの和をsum(a, b)と表現することにする
  例えば2から5までのA^iの和は
    A^2 + A^3 + A^4 + A^5 = sum(2, 5)
  と表現する

考えたこと

・Xが2のN乗なら簡単

例えばX = 8ならsum(0, 7)を求めればいい。このとき
  A^4 = A^0 * A^4
  A^5 = A^1 * A^4
  A^6 = A^2 * A^4
  A^7 = A^3 * A^4
なので、4から7までの和sum(4, 7)は
  sum(4, 7) = A^4 * (A^0 + A^1 + A^2 + A^3)
       = A^4 * sum(0, 3)
で求められる。よってsum(0, 7)は
  sum(0, 7) = sum(0, 3) + sum(4, 7)
       = sum(0, 3) + A^4 * sum(0, 3)
となる。同様にsum(0, 3), sum(0, 1)も
  sum(0, 3) = sum(0, 1) + sum(2, 3)
       = sum(0, 1) + A^2 * sum(0, 1)
  sum(0, 1) = sum(0, 0) + sum(1, 1)
       = sum(0, 0) + A^1 * sum(0, 0)
と求めることができる。sum(0, 0)はA^0で1なので、sum(0, 7)を求める式中のすべての項をA^iで表現することができた。この方法なら計算量はO(logX)で済む。

・端数をどうするか

X = 11の場合を考える。sum(0, 10)を求めればいいが
  sum(0, 10) = sum(0, 7) + sum(8, 10)
となるので、sum(8, 11)の部分が求められればいいことになる。ここで上記の2のN乗のときと同様に
  A^8 = A^0 * A^8
  A^9 = A^1 * A^8
  A^10 = A^2 * A^8
と考えると、sum(8, 10)は
  sum(8, 10) = A^8 * (A^0 + A^1 + A^2)
        = A^8 * sum(0, 2)
となる。あとは同様に
  sum(0, 2) = sum(0, 1) + sum(2, 2)
  sum(2, 2) = A^2 * sum(0, 0)
というように、2のN乗問題に帰着させることができる。

sum(0, 10)を求めるためにsum(0, 7)とsum(0, 2)が必要で、
sum(0, 2)を求めるためにsum(0, 0)とsum(0, 1)が必要で、
sum(0, 1)を求めるためにsum(0, 0)が必要。
という格好になっているので、再帰関数で書くと楽そうである。

・どんな再帰にするか

上記の通りsum(0, 10)を求めるためにsum(0, 7)とsum(0, 2)に分ければいいので、再帰関数にxが入力されたとすると次のような処理をすればいい。
(1) xを超えない範囲で最大の2のN乗を求める
   ここで求めた2のN乗をeと呼ぶことにする
(2) e - 1を入力に与えて再帰関数を実行する
   sum(0, 10)におけるsum(0, 7)側の処理にあたる
(3) x - eを入力に与えて再帰関数を実行する
   sum(0, 10)におけるsum(0, 2)側の処理にあたる
(4) (2)の結果 + (3)の結果 × A^eを返す

こんな感じでいい。何度も同じ入力で再帰関数を実行することになるので、一度計算した結果は関数の外で定義した配列に保存しておいて、二度目の実行時には配列を参照するだけにするといい。(いわゆるメモ化再帰)

書いたコードと提出結果

#include <bits/stdc++.h>

std::map< long long, long long > memo;
std::map< long long, long long > twoN;
long long mod;

// 0からendまでの和を求める
long long sum(long long end){
    // endが0ならば0から0までの和になるので、Aの0乗のみである
    if(end == 0){
        return 1ll;
    }
    
    // すでに計算済みならば同じ計算を何度もやる必要はないので、覚えている結果を返す
    if(memo.find(end) != memo.end()){
        return memo[end];
    }
    
    // endを超えない範囲で最大の2のN乗を求める
    long long lbit = 60ll;
    while( ( (1ll << lbit) & end ) == 0){
        lbit--;
    }
    long long e = (1ll << lbit);
    
    long long sum1 = sum(end - e);
    long long sum2 = sum(e - 1ll);
    
    long long ret = ( (sum1 * twoN[e]) % mod + sum2) % mod;
    // 後で使うために計算結果を覚えておく
    memo[end] = ret;
    return ret;
}


int main(){
    long long A, X, M;
    std::cin >> A >> X >> M;
    mod = M;
    
    if(M == 1){
        std::cout << 0 << std::endl;
        return 0;
    }
    
    // Aの1乗, 2乗, 4乗, 8乗, …は何度も使うのであらかじめ計算しておく
    // Aの2N乗 = AのN乗 ×AのN乗
    twoN[0] = 1ll;
    twoN[1] = A;
    for(int i=1; i<=60; i++){
        twoN[1ll << i] = (twoN[1ll << (i-1)] * twoN[1ll << (i-1)]) % M;
    }
    
    std::cout << sum(X-1) << std::endl;
    
    return 0;
}

終わりに

再帰関数って書くの難しいよね。仕事じゃめったに書かないし。

この記事が気に入ったらサポートをしてみませんか?