見出し画像

簡単な行列積計算プログラム|行列積高速化#4

この記事は、以下の記事を分割したものです。
[元の記事]行列積計算を高速化してみる
一括で読みたい場合は、元の記事をご覧ください。

チューニングに必要な準備ができたので、行列積計算プログラムを作成していきます。まず、BLASの行列積ルーチンDGEMMの公式を再確認しておきましょう。

C = alpha * A * B + beta * C

ここで、A, B, Cは行列、alpha, betaはスカラー係数を表しています。

関数インターフェースは、上記のテストプログラムで使用できるように、CBLASのcblas_dgemm関数と同じものとし、関数名をmyblas_dgemmとします。

void myblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
                 const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
                 const int K, const double alpha, const double  *A,
                 const int lda, const double  *B, const int ldb,
                 const double beta, double  *C, const int ldc);

myblas_dgemmでは、第一引数Orderの処理と引数のエラー処理だけを行い、Orderの違いを吸収してしまいます。第一引数Orderの処理は、CblasRowMajorが選ばれた場合に、行列Aと行列Bを入れ替えるという処理をします。これは、cblas_dgemm関数の実装を参照にしています。

実際の計算には、myblas_dgemm_mainという別の関数を用意しています。

void myblas_dgemm_main( gemm_args_t* args );

引数はgemm_args_t型の構造体にまとめることにしました。これは、高速化作業で引数を追加する必要が出てきた時、引数の変更作業を楽にするためです。構造体の定義は、下記の通りです。

typedef struct _gemm_args_t {
       size_t       TransA;
       size_t       TransB;
       size_t       M;
       size_t       N;
       size_t       K;
       double       alpha;
       const double *A;
       size_t       lda;
       const double *B;
       size_t       ldb;
       double       beta;
       double       *C;
       size_t       ldc;
} gemm_args_t;

転置設定のTransAとTransBは、元々enum CBLAS_TRANSPOSE型でしたが、今後のためを考えてビットフラグに変更しました。フラグは、転置を1ビット目、複素共役を2ビット目とし、下記のようなマスクを用意しています。

#define  MASK_TRANS  0x01

#define  MASK_CONJ   0x02


4-1. myblas_dgemm関数

行列Aと行列Bの入れ替えが必要なCblasRowMajorの場合の引数処理及びエラー処理は、下記のようになります。CblasColMajorの場合は、TransA, TransBの処理やM,Nの入れ替えを元に戻したコードを通るように条件分岐しています。

       gemm_args_t args={0,0,0,0,0,0e0,NULL,0,NULL,0,0e0,NULL,0};

       int info = 0;

       if( Order == CblasColMajor ){
       
            /** 省略 **/
       
       }else if( Order == CblasRowMajor ){

               // Transpose Set-up

               if( TransA == CblasNoTrans  ){ args.TransB = 0; }
               if( TransA == CblasTrans    ){ args.TransB = MASK_TRANS; }
               if( TransA == CblasConjTrans){ args.TransB = MASK_TRANS | MASK_CONJ; }

               if( TransB == CblasNoTrans  ){ args.TransA = 0; }
               if( TransB == CblasTrans    ){ args.TransA = MASK_TRANS; }
               if( TransB == CblasConjTrans){ args.TransA = MASK_TRANS | MASK_CONJ; }

               // Error Check

               if( C    == NULL ) info=13;
               if( B    == NULL ) info=10;
               if( A    == NULL ) info= 8;
               if( info ){ myblas_xerbla("myblas_dgemm",info); return; }

               int ma = ( (args.TransB & MASK_TRANS) ? M : K );
               int mb = ( (args.TransA & MASK_TRANS) ? K : N );
               if( ldc  < N     ) info=14;
               if( ldb  < mb    ) info=11;
               if( lda  < ma    ) info= 9;
               if( K    < 0     ) info= 6;
               if( N    < 0     ) info= 5;
               if( M    < 0     ) info= 4;
               if( info ){ myblas_xerbla("myblas_dgemm",info); return; }

               // No Computing 

               if( M    == 0    ) return;
               if( N    == 0    ) return;
               if( K    == 0   && beta == 1e0  ) return;
               if( alpha== 0e0 && beta == 1e0  ) return;

               // Set arguments

               args.M     = N;
               args.N     = M;
               args.K     = K;
               args.alpha = alpha;
               args.A     = B;
               args.lda   = ldb;
               args.B     = A;
               args.ldb   = lda;
               args.beta  = beta;
               args.C     = C;
               args.ldc   = ldc;

       }
       
       myblas_dgemm_main( &args );

ここで、myblas_xerbla関数は、BLASのエラー出力ルーチンXERBLAを模したものです。下記のような、BLASのエラーメッセージを出力します。

void myblas_xerbla( const char* name, int info ){
 printf(" ** On entry to %s parameter number %d had an illegal value\n",name,info);
}


4-2. myblas_dgemm_main関数

実際の行列積計算は、行列Aと行列Bの転置有無の組み合わせで、全部で4パターンのコードが必要になります。転置なしの場合は、次のようなコードになります。

                for( size_t j=0; j<N; j++ ){
                   for( size_t i=0; i<M; i++ ){
                       AB=0e0;
                       for( size_t k=0; k<K; k++ ){
                          AB = AB + (*A)*(*B);
                          A += lda;
                          B++;
                       }
                       *C=beta*(*C) + alpha*AB;
                       A = A - lda*K + 1;
                       B = B - K;
                       C++;
                   }
                   A = A - M;
                   B = B + ldb;
                   C = C - M + ldc;
               }

行列CがM×N行列(Mがメモリ連続方向)で、Kに依存しないので、無駄なメモリアクセスを避けるために、外側からN→M→Kの順の三重ループにしています。本来、CはM*N*K回のメモリアクセスが必要ですが、この順にすると、M*N回のメモリアクセスで済みます。また、例えばM→N→Kの順にすると、Cはストライドアクセスが必要になります。これを避けるため、N→M→Kの順にしています。

一方、行列AはM×K行列なので、Kループはメモリ連続方向ではありません。このため、ldaずつ飛び飛びのストライドアクセスになってしまいます。逆に、行列BはK×N行列のため、Kループがメモリ連続方向なので連続アクセスになります。行列Aがどうしてもストライドアクセスになるため、スピードがほとんど出ないと予想されます。

実際、計算速度を測定してみると、次のようになりました。

Max  Peak MFlops per Core: 52800 MFlops 
Base Peak MFlops per Core: 46400 MFlops 
size  , elapsed time[s],          MFlops,   base ratio[%],    max ratio[%] 
   16,     4.05312E-06,         2210.64,         4.76432,         4.18683 
   32,     3.40939E-05,         2012.33,         4.33691,         3.81123 
   64,     0.000264883,         2025.71,         4.36575,         3.83657 
  128,      0.00234389,         1810.43,         3.90179,         3.42885 
  256,       0.0287192,         1175.21,         2.53278,         2.22577 
  512,        0.263068,         1023.39,         2.20559,         1.93824 
 1024,         6.27475,         342.743,        0.738671,        0.649135 
 2048,         113.116,          151.99,        0.327564,         0.28786 

M=N=K=2048で、基本理論ピーク性能の0.3%しか出ていません。どうしようもなく、遅いですね…。

ちなみに、4パターン内で最も高速なのは、行列Aだけを転置した場合で、実装コードは次のようになります。

                for( size_t j=0; j<N; j++ ){
                   for( size_t i=0; i<M; i++ ){
                       AB=0e0;
                       for( size_t k=0; k<K; k++ ){
                          AB = AB + (*A)*(*B);
                          A++;
                          B++;
                       }
                       *C=beta*(*C) + alpha*AB;
                       A = A - K + lda;
                       B = B - K;
                       C++;
                   }
                   A = A - lda*M;
                   B = B + ldb;
                   C = C - M + ldc;
               }

行列Aを転置した場合は、最内のKループにおいて、AもBも連続アクセスになっています。この時の、計算速度は次のようになります。

Max  Peak MFlops per Core: 52800 MFlops 
Base Peak MFlops per Core: 46400 MFlops 
size  , elapsed time[s],          MFlops,   base ratio[%],    max ratio[%] 
   16,     4.05312E-06,         2210.64,         4.76432,         4.18683 
   32,     3.00407E-05,         2283.83,         4.92205,         4.32544 
   64,     0.000258923,         2072.34,         4.46625,         3.92489 
  128,       0.0022769,          1863.7,          4.0166,         3.52974 
  256,       0.0209579,         1610.42,         3.47073,         3.05003 
  512,        0.178289,         1510.03,         3.25438,         2.85991 
 1024,         1.40407,         1531.71,          3.3011,         2.90097 
 2048,         11.1541,         1541.35,         3.32188,         2.91923 

この場合は、基本周波数の理論ピーク性能比で3.3%以上のスピードが出ます。

ということで、ストライドアクセスだととても遅いことがわかりますね。

4-3. 初期プログラム

最後に、初期プログラムの全体を載せておきます。

void myblas_dgemm_main( gemm_args_t* args ){

       size_t       TransA = args->TransA;
       size_t       TransB = args->TransB;
       size_t       M      = args->M;
       size_t       N      = args->N;
       size_t       K      = args->K;
       double       alpha  = args->alpha;
       const double *A     = args->A;
       size_t       lda    = args->lda;
       const double *B     = args->B;
       size_t       ldb    = args->ldb;
       double       beta   = args->beta;
       double       *C     = args->C;
       size_t       ldc    = args->ldc;
       
       double       AB;
       
       if( TransA & MASK_TRANS ){
           if( TransB & MASK_TRANS ){
               for( size_t j=0; j<N; j++ ){
                   for( size_t i=0; i<M; i++ ){
                       AB=0e0;
                       for( size_t k=0; k<K; k++ ){
                          AB = AB + (*A)*(*B);
                          A++;
                          B+=ldb;
                       }
                       *C=beta*(*C) + alpha*AB;
                       A = A - K + lda;
                       B = B - ldb*K;
                       C++;
                   }
                   A = A - lda*M;
                   B = B + 1;
                   C = C - M + ldc;
               }
           }else{
               for( size_t j=0; j<N; j++ ){
                   for( size_t i=0; i<M; i++ ){
                       AB=0e0;
                       for( size_t k=0; k<K; k++ ){
                          AB = AB + (*A)*(*B);
                          A++;
                          B++;
                       }
                       *C=beta*(*C) + alpha*AB;
                       A = A - K + lda;
                        B = B - K;
                       C++;
                   }
                   A = A - lda*M;
                   B = B + ldb;
                   C = C - M + ldc;
               }
           }
       }else{
           if( TransB & MASK_TRANS ){
               for( size_t j=0; j<N; j++ ){
                   for( size_t i=0; i<M; i++ ){
                       AB=0e0;
                       for( size_t k=0; k<K; k++ ){
                          AB = AB + (*A)*(*B);
                          A += lda;
                          B += ldb;
                       }
                       *C=beta*(*C) + alpha*AB;
                       A = A - lda*K + 1;
                       B = B - ldb*K;
                       C++;
                   }
                   A = A - M;
                   B = B + 1;
                   C = C - M + ldc;
               }
           }else{
               for( size_t j=0; j<N; j++ ){
                   for( size_t i=0; i<M; i++ ){
                       AB=0e0;
                       for( size_t k=0; k<K; k++ ){
                          AB = AB + (*A)*(*B);
                          A += lda;
                          B++;
                       }
                       *C=beta*(*C) + alpha*AB;
                       A = A - lda*K + 1;
                       B = B - K;
                       C++;
                   }
                   A = A - M;
                   B = B + ldb;
                   C = C - M + ldc;
               }
           }
       }
}

これで、ようやく準備が整いました。以降では、高速化の方法を実際にプログラムを書きながら解説していこうと思います。


次の記事

元の記事はこちらです。

ソースコードはGitHubで公開しています。


この記事が気に入ったらサポートをしてみませんか?