見出し画像

Pythonでいろんな確率分布の乱数生成器を作るためのTips

少し前に統計的仮説検定の謎シミュレーターに関する記事を書きました。

そのシミュレーターでは、指定した確率分布から標本を作成し、その標本同士で検定計算を行うということをしております。

確率分布は複数個用意する必要があるため、全てを愚直に書いていくと物凄いコード量になってしまうでしょう。そうなると、シンプルに実装が大変ですし、一部の処理を変更したいときに全部の確率分布分のコードを変更することになるのでかなりストレスになります。

実際、データサイエンスを学び始めた方は過去に作ったコードを見返しても意味が分からなかったり、何故か動かなかったりするのではないでしょうか。(私も学生のときに書いていたコードはそうなっていました…)

今回の記事では、そうならないようにどのような記述をしたかを軽く書いていきます。

紹介する範囲のコードは↓です。MITなんで、好きにインストールして好きに遊んでください!

※この記事では、クラスの知識を利用します。クラスについて分からない方は公式リファレンスを一読するとより理解が深まると思います。


基底クラス

乱数生成器では、確率分布以外の部分では同じ処理を行います。そのため、それら共通の処理をまとめた基底クラスを定義しておくとかなり楽にコードを変更したり、新しい確率分布の乱数生成器を作成するときにコードを書かなくて済みます。

実際に確率計算を行う際にはscipy.statsを活用しました。確率分布から乱数を生成したり、累積確率を算出するメソッドが共通しており、とても便利です。

私が作ったシミュレーターでは、次のような基底クラスを作成しました。

class DistGenerator:
    """distribution generator

    Attributes:
        sample_size (int): sample size
        funcs (GeneratorFunc): functions for generator
        func_param (dict[str, int | float | np.ndarray]): functions parameters
        plot_prob (float): plot x axis range of distribution function
        dist_name (str): distribution name
        dist_type (Literal["discrete", "continue"]): distribution type
    """

    def __init__(self, sample_size: int = 50) -> None:
        """intiation

        Args:
            sample_size (int, optional): sample size. Defaults to 50.
        """
        self.sample_size = sample_size

        self.funcs: GeneratorFunc
        self.func_param: dict[str, int | float | np.ndarray]
        self.plot_prob: float = 0.999

        self.dist_name: str
        self.dist_type: Literal["discrete", "continue"]

    def create_sample(self) -> np.ndarray:
        """create sample by distribution

        Returns:
            np.ndarray: sample
        """
        param = self.func_param.copy()
        param.update(size=self.sample_size)

        return self.funcs["rand"](**param)

    def plot_range(self) -> tuple[float, float]:
        """return plot range

        Raises:
            NotImplementedError: must override

        Returns:
            tuple[float, float]: x axis range for graph
        """
        raise NotImplementedError("Must override!!")

    def density_points(self, delta: int = 100) -> dict[str, np.ndarray]:
        """return density points

        Args:
            delta (int, optional): points size. Default is 100.
        """
        ranges = self.plot_range()

        x_vec = np.linspace(ranges[0], ranges[1], delta)
        param = self.func_param.copy()
        param.update(x=x_vec)
        y_vec = self.funcs["stat_prob"](**param)

        return {"x": x_vec, "y": y_vec}

インスタンス変数

乱数を生成し、どのような確率分布であるか描画するための各種変数をインスタンス変数として保持します。

  • sample_size: 標本数

  • funcs: 関数

※GeneratorFuncは、自前で作った辞書型のことです。Callableとは関数の型のことです。あまり知られていませんが、変数としてnp.sumなどの関数を定義することができます。

class GeneratorFunc(TypedDict):
    """generator function

    Attributes:
        rand (Callable): create sample function
        stat_prob (Callable): caluculate probability function
    """

    rand: Callable
    stat_prob: Callable
  • func_param: 関数で使用するパラメータ

  • plot_prob: 確率分布の描画範囲(正規分布のような定義域が$${[- \infty, \infty]}$$であるものは、全範囲を描画することができないため、面積のN%分描画するか決定する)

  • dist_name: 確率分布の名前

  • dist_type: 確率分布の種類(離散か連続か)

メソッドオブジェクト

実際に乱数を生成したり、描画用のデータを記述するコードです。継承先で各種確率分布の関数や設定をself.funcsやself.func_paramに入れることで新たにコードを書かなくても機能が使えるようにします。

  • create_sample: 乱数を生成する関数

    def create_sample(self) -> np.ndarray:
        """create sample by distribution

        Returns:
            np.ndarray: sample
        """
        param = self.func_param.copy()
        param.update(size=self.sample_size)

        return self.funcs["rand"](**param)
  • plot_range: 確率分布の描画用定義域を求める関数

※NotImplementedErrorとは、関数のオーバーライドを前提とするエラーです。基底クラスのままこの関数を実行すると、エラーで処理が終了します。
※今見たらpropertyでいいですねこれ…

    def plot_range(self) -> tuple[float, float]:
        """return plot range

        Raises:
            NotImplementedError: must override

        Returns:
            tuple[float, float]: x axis range for graph
        """
        raise NotImplementedError("Must override!!")
  • density_points: 定義域に対して、描画のポイントを生成する関数

    def density_points(self, delta: int = 100) -> dict[str, np.ndarray]:
        """return density points

        Args:
            delta (int, optional): points size. Default is 100.
        """
        ranges = self.plot_range()

        x_vec = np.linspace(ranges[0], ranges[1], delta)
        param = self.func_param.copy()
        param.update(x=x_vec)
        y_vec = self.funcs["stat_prob"](**param)

        return {"x": x_vec, "y": y_vec}

継承先クラス

ここまで基底クラスを作れば、後は楽して様々な確率分布の生成器を作ることができます。

class NormGenerator(DistGenerator):
    """norm distribution generator"""

    def __init__(self, sample_size: int = 50, mu: float = 0, sigma: float = 1) -> None:
        """initiation

        Args:
            sample_size (int, optional): sample size. Defaults to 50.
            mu (float, optional): mean of population distribution. Defaults to 0.
            sigma (float, optional): standard deviation of population distribution.
                Defaults to 1.
        """
        super().__init__(sample_size)

        self.funcs = {
            "rand": stats.norm.rvs,
            "stat_prob": stats.norm.pdf,
        }
        self.func_param = {"loc": mu, "scale": sigma}

        self.dist_name = f"Norm({mu}, {sigma})"
        self.dist_type = "continue"

    def plot_range(self) -> tuple[float, float]:
        """return plot range

        Returns:
            tuple[float, float]: x axis range for graph
        """
        param = self.func_param.copy()
        param.update(confidence=self.plot_prob)

        return stats.norm.interval(**param)

基底クラスよりコード量がかなり短いことが分かります。ここでやっていることは以下の二つだけです。これでも、継承しているので、乱数を生成する機能は備わっています。

  • __init__内で確率分布独自の関数やパラメータを設定

  • plot_range関数の作成

他の確率分布も同じ方法で実装することができます。新規の確率分布の乱数生成器も10分程度で作成することができるようになりました。これでかなり快適に開発を進めることができます。

まとめると、以下の3つが重要となります。

  • 同じ処理は基底クラスにまとめてしまう

  • 関数も変数として上手くパイプラインに組み込む

  • 個別の設定が必要な部分のみ、継承クラスで実装する

実は、シミュレーターのもう一つの軸である検定計算を行う部分も同じ考えでコード量を圧縮しています。良かったら見てみて、どの部分を作りこんでいるかを考えてみてください!

ということで、AIエンジニアになりたての私が何を考えてコードを書いているかを一部紹介する記事でした。Pythonを書いている人でも、保守性や可読性を考慮したコードを書ける人は少ないと思うので、こういうことができるだけでも大きな差別化になります。

理論の勉強ももちろん大事ですが、たまにはコーディングの勉強もしてみるといいよというのが伝われば幸いです。

最後までお読みいただきありがとうございました。

この記事が気に入ったらサポートをしてみませんか?