【matplotlib】ウォーターフォールチャート

目的

社内のプレゼン資料でウォーターフォールチャートを使いたかったのでmatplotlibを使って作成してみました。
自分の担当している会社を全部まとめて「売上が上がったのか下がったのか」、「要因はなにか」を示すのに使えると思います。
浮いているところは透明の棒グラフを設置しています。
完成形は下のような形になります。

表の準備

グラフに使う表は2つで、別々で用意したほうが後で楽だと思います。

まず会社別に「前年の売上」、「要因A・B・C…」、「今年の売上」の表を作ります。

次に透明の棒グラフのためのデータを作ります。
1つ前までの値を全部足した表を作ります。このとき、両サイドの棒グラフは浮かせたくないので0にします。他にも浮かせたくないところは0にしてください。

グラフの作成

ここからmatplotlibでグラフを作成します。
まず呪文、color_listは好きな色を入れてください。

fig, ax = plt.subplots()
plt.rcParams['font.family'] = 'MS Gothic'
plt.rcParams['font.size'] = 18
plt.rcParams["figure.figsize"] = (2015)
color_list =['#00361e','#006243','#1a915d','#56c278','#8af697','#beffcd']

続いてグラフの描画
積み上げ棒グラフは、普通の棒グラフにbottomという引数を指定することで作成できます。
A、B、Cと積み上げたい場合、BのbottomはA、CのbottomはAとBになります。
今回わかりやすく全て書きましたが、ここはfor文で回してもいいかもしれません。ついでに色も指定しておきます。

ax.bar(base_df.columns, base_df.loc['base'],color='r')
ax.bar(df.columns, df.loc['A社'], label='A社', bottom=base_df.loc['base'], color=color_list[0])
ax.bar(df.columns, df.loc['B社'], label='B社', bottom=base_df.loc['base'] + df.loc['A社'], color=color_list[2])
ax.bar(df.columns, df.loc['C社'], label='C社', bottom=base_df.loc['base'] + df.loc['A社'] + df.loc['B社'], color=color_list[4])

これで下のようなグラフになります。あとは赤のところを透明にすれば完成です。

透明にするには引数alpha=0に設定すればいいので、最終的なコードは下のようになります。

ax.bar(base_df.columns, base_df.loc['base'], alpha=0)
ax.bar(df.columns, df.loc['A社'], label='A社', bottom=base_df.loc['base'], color=color_list[0])
ax.bar(df.columns, df.loc['B社'], label='B社', bottom=base_df.loc['base'] + df.loc['A社'], color=color_list[2])
ax.bar(df.columns, df.loc['C社'], label='C社', bottom=base_df.loc['base'] + df.loc['A社'] + df.loc['B社'], color=color_list[4])

あとは適当に装飾を付けて完成です。
#装飾

plt.title('2022年売上報告')
ax.legend(bbox_to_anchor=(1,1), loc='upper left')
plt.ylabel('売上金額(k)')
plt.yticks(np.arange(2000,14000,2000))
columns = df.columns
for column in columns:
    total = df[column].sum()
    if total >=0:
        y = df[column].sum() + base_df[column].sum()
        plt.text(column, y, f'+{total:,}',ha='center')
    else:
        y = base_df[column].sum()
        plt.text(column, y, f'{total:,}',ha='center',color='r')

この記事が気に入ったらサポートをしてみませんか?