見出し画像

【Go】AWS Lambda関数の共通処理をミドルウェアに実装する

こんにちは、はしるとりです。
ナビタイムジャパンでSREを担当しています。

社内のいくつかのプロダクトでは、API Gateway + AWS Lambda + Go で開発・運用しています。
Lambdaハンドラーにはたいていの場合、入力値の検証やログ出力、エラーハンドリングを実装することになると思います。
API GatewayのエンドポイントごとにLambda関数を作成しようとすると、各ハンドラーに同じ実装を繰り返し書くことになります。
Webアプリケーションフレームワークであれば、こういった処理はミドルウェアに実装したいと考えるでしょう。
ミドルウェアを実装することにより、ハンドラーのビジネスロジックと関心事以外を分離することができます。

今回は、サードパーティのライブラリを使用せずに、Lambdaの各ハンドラーにミドルウェアを適用して共通処理を分離する方法を紹介したいと思います。


ミドルウェアについて

本題に入る前に、標準ライブラリのnet/httpでWebサーバーを作成するときのミドルウェアの実装パターンについておさらいします。すでに知ってるよ!という方はこの章は読み飛ばしていただいて大丈夫です。

ミドルウェアはhttp.Handlerを引数にとり、http.Handlerを返す関数を作ることで、引数のhttp.Handlerの前後に処理を挟み込んだり、ミドルウェアをさらに別のミドルウェアでラップすることができます。
典型的なパターンは以下のようになります。

package main

import (
	"fmt"
	"net/http"
)

func middleware1(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Println("middleware 1 start...")
		next.ServeHTTP(w, r)
		fmt.Println("middleware 1 end")
	})
}

func middleware2(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Println("middleware 2 start...")
		next.ServeHTTP(w, r)
		fmt.Println("middleware 2 end")
	})
}

func theHandler(w http.ResponseWriter, r *http.Request) {
	fmt.Println("handler start...")
	w.Write([]byte("Hello!"))
	fmt.Println("handler end")
}

func main() {
	http.Handle("/", middleware1(middleware2(http.HandlerFunc(theHandler))
	http.ListenAndServe(":8080", nil)
}

// => curl http://localhost:8080
// middleware 1 start...
// middleware 2 start...
// handler start...
// handler end
// middleware 2 end
// middleware 1 end

middleware1 -> middleware2 -> handler -> middleware2 -> middleware1 の順に実行されていることがわかると思います。
ミドルウェアが増えたときに、チェーンするコードが長くなってくるため以下のような合成の関数を書くこともあります。https://github.com/justinas/alice を使ってもよいでしょう。

type middlewareFunc func(http.Handler) http.Handler

func apply(middlewares ...middlewareFunc) middlewareFunc {
	return middlewareFunc(func(h http.Handler) http.Handler {
		for i := len(middlewares) - 1; i >= 0; i-- {
			h = middlewares[i](h)
		}
		return h
	})
}
func main() {
	http.Handle("/", apply(middleware1, middleware2, middleware3)(http.HandlerFunc(handler)))
	http.ListenAndServe(":8080", nil)
}

middlewaresの末尾からループを実行しているのが奇妙に見えますが、これによりmiddlewaresの先頭から順に適用されることになります。

これでnet/httpのミドルウェア実装についてのおさらいはできました。
ここからはLambdaのコードにおいてこれと同様のことをする方法について紹介していきます。

Lambdaハンドラーの例

次のように複数のLambdaハンドラーがあるとして、素直に実装すると次のようになります。

// hello/main.go

func handler(ctx context.Context, request events.APIGatewayProxyRequest) (response events.APIGatewayProxyResponse, err error) {
	// panic時の処理
	defer func() {
		if r := recover(); r != nil {
			response = events.APIGatewayProxyResponse{StatusCode: 500, Body: `panic`}
			err = nil
			return
		}
	}()

	// 入力値のバインディング
	input := ParseInput[Param](request)
	// validation
	if err := Validate(input); err != nil {
		return events.APIGatewayProxyResponse{StatusCode: 400, Body: err.Error()}, nil
	}

	// ビジネスロジック
	result, err := doHello(ctx, input)
	if err != nil {
		return events.APIGatewayProxyResponse{StatusCode: 500, Body: err.Error()}, nil
	}

	// JSON返却(ステータスコード200)
    response = JsonResponse(200, result)
	// ロギング
	Log(ctx, request, response)
	return response, nil
}

func main() {
	lambda.Start(handler)
}
// bye/main.go

func handler(ctx context.Context, request events.APIGatewayProxyRequest) (response events.APIGatewayProxyResponse, err error) {
	// panic時の処理
	defer func() {
		if r := recover(); r != nil {
			response = events.APIGatewayProxyResponse{StatusCode: 500, Body: `panic`}
			err = nil
			return
		}
	}()

	// 入力値のバインディング
	input := ParseInput[Param](request)
	// validation
	if err := Validate(input); err != nil {
		return events.APIGatewayProxyResponse{StatusCode: 400, Body: err.Error()}, nil
	}

	// ビジネスロジック
	result, err := doBye(ctx, input)
	if err != nil {
		return events.APIGatewayProxyResponse{StatusCode: 500, Body: err.Error()}, nil
	}

	// JSON返却(ステータスコード200)
    response = JsonResponse(200, result)
	// ロギング
	Log(ctx, request, response)
	return response, nil
}

func main() {
	lambda.Start(handler)
}

差分は doHello(ctx, input) と doBye(ctx, input) の部分だけです。

  • どの部分が重要なのかが一見してわかりにくい

  • ハンドラーを追加するときに漏れなくコピーしてビジネスロジックの部分だけを書き換える必要がある

  • 共通処理を追加したいとなったら全ハンドラーのコードに間違えないように差し込まないといけない

と、今後のメンテナンスが大変になることが容易に想像できると思います。

ここにミドルウェアを導入することで、改善していきたいと思います。

ミドルウェア適用後の全体像

最終的な全体像は以下のようになります。

// hello/main.go

func handler(ctx context.Context, request events.APIGatewayProxyRequest) (response events.APIGatewayProxyResponse, err error) {
	input := ctx.Value("ctxKeyParam").(helloParam)

	// ビジネスロジック
	result, err := doHello(ctx, input)
	if err != nil {
		return events.APIGatewayProxyResponse{StatusCode: 500, Body: err.Error()}, nil
	}

	// JSON返却
	return middleware.JsonResponse(200, result), nil
}
func main() {
	m := middleware.NewMiddleware(middleware.DefaultMiddlewares[helloParam]()...)
	lambda.Start(m.Apply(handler))
}
// bye/main.go

func handler(ctx context.Context, request events.APIGatewayProxyRequest) (response events.APIGatewayProxyResponse, err error) {
	input := ctx.Value("ctxKeyParam").(byeParam)

	// ビジネスロジック
	result, err := doBye(ctx, input)
	if err != nil {
		return events.APIGatewayProxyResponse{StatusCode: 500, Body: err.Error()}, nil
	}

	// JSON返却
	return middleware.JsonResponse(200, result), nil
}

func main() {
	m := middleware.NewMiddleware(middleware.DefaultMiddlewares[byeParam]()...)
	lambda.Start(m.Apply(handler))
}
// middleware/middlewares.go

package middleware

import (
	"context"
	"encoding/json"
	"fmt"
	"io"

	"github.com/aws/aws-lambda-go/events"
)

type LambdaHandlerFunc func(context.Context, events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error)
type LambdaMiddlewareFunc func(next LambdaHandlerFunc) LambdaHandlerFunc

type Middleware struct {
	middlewares []LambdaMiddlewareFunc
}

func (m *Middleware) Use(middlewares ...LambdaMiddlewareFunc) {
	m.middlewares = append(m.middlewares, middlewares...)
}
func (m *Middleware) Apply(handler LambdaHandlerFunc) LambdaHandlerFunc {
	for i := len(m.middlewares) - 1; i >= 0; i-- {
		handler = m.middlewares[i](handler)
	}
	return handler
}
func DefaultMiddlewares[T any]() []LambdaMiddlewareFunc {
	return []LambdaMiddlewareFunc{
		Recover(),
		ParseInput[T](),
		Log(),
	}

}
func NewMiddleware(middlewares ...LambdaMiddlewareFunc) *Middleware {
	return &Middleware{
		middlewares: middlewares,
	}
}

// Recover panicからの回復処理をするミドルウェア
func Recover() LambdaMiddlewareFunc {
	return func(next LambdaHandlerFunc) LambdaHandlerFunc {
		return func(ctx context.Context, request events.APIGatewayProxyRequest) (response events.APIGatewayProxyResponse, returnErr error) {
			defer func() {
				if r := recover(); r != nil {
					err, ok := r.(error)
					if !ok {
						err = fmt.Errorf("%v", r)
					}
					response = events.APIGatewayProxyResponse{StatusCode: 500, Body: err.Error()}
				}
			}()

			return next(ctx, request)
		}
	}
}

// ParseInput リクエストをstructにparseするミドルウェア
func ParseInput[T any]() LambdaMiddlewareFunc {
	return func(next LambdaHandlerFunc) LambdaHandlerFunc {
		return func(ctx context.Context, request events.APIGatewayProxyRequest) (response events.APIGatewayProxyResponse, returnErr error) {
			var t T

			pr, pw := io.Pipe()
			go func() {
				if err := json.NewEncoder(pw).Encode(request.QueryStringParameters); err != nil {
					response = events.APIGatewayProxyResponse{StatusCode: 400, Body: err.Error()}
				}
			}()
			if err := json.NewDecoder(pr).Decode(&t); err != nil {
				response = events.APIGatewayProxyResponse{StatusCode: 400, Body: err.Error()}
				return
			}

			// handlerへの受け渡しはcontextで行う
			ctx = context.WithValue(ctx, "ctxKeyParam", t)

			return next(ctx, request)
		}
	}
}

// Log リクエストログを出力するミドルウェア
func Log() LambdaMiddlewareFunc {
	return func(next LambdaHandlerFunc) LambdaHandlerFunc {
		return func(ctx context.Context, request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
			res, err := next(ctx, request)
			fmt.Printf("req: %+v, res: %+v\n", request, res)
			return res, err
		}
	}
}

// JsonResponse JSONレスポンスを返却する
func JsonResponse(status int, body any) events.APIGatewayProxyResponse {
	b, err := json.Marshal(body)
	if err != nil {
		return events.APIGatewayProxyResponse{StatusCode: 500}
	}
	return events.APIGatewayProxyResponse{StatusCode: status, Body: string(b), Headers: map[string]string{"Content-Type": "application/json"}}
}

いかがでしょうか。 middleware パッケージに共通処理が集約され、ハンドラーの処理がすっきりして見通しがよくなったことが見て取れると思います。

ここから細かく見ていきます。

解説

型定義

type LambdaHandlerFunc func(context.Context, events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error)
type LambdaMiddlewareFunc func(next LambdaHandlerFunc) LambdaHandlerFunc

Lambdaのハンドラーおよびミドルウェアの型定義です。
LambdaMiddlewareFunc は、次に実行する LambdaHandlerFunc を受け取り、LambdaHandlerFunc でラップして返す型になっています。

これは、アイデアとしてはEchoのMiddlewareから拝借しています。
開発者はこの型を満たすようにミドルウェア関数を書けばよいわけです。
https://github.com/labstack/echo/blob/a2e7085094bda23a674c887f0e93f4a15245c439/echo.go#L123-L127

ミドルウェアの登録・適用

func (m *Middleware) Use(middlewares ...LambdaMiddlewareFunc) {
	m.middlewares = append(m.middlewares, middlewares...)
}
func (m *Middleware) Apply(handler LambdaHandlerFunc) LambdaHandlerFunc {
	for i := len(m.middlewares) - 1; i >= 0; i-- {
		handler = m.middlewares[i](handler)
	}
	return handler
}
func DefaultMiddlewares[T any]() []LambdaMiddlewareFunc {
	return []LambdaMiddlewareFunc{
		Recover(),
		ParseInput[T](),
		Log(),
	}

}
func NewMiddleware(middlewares ...LambdaMiddlewareFunc) *Middleware {
	return &Middleware{
		middlewares: middlewares,
	}
}

ミドルウェアの登録および適用をする関数です。
これをhandler実行時にラップすることで、ミドルウェアが適用されます。

func main() {
	m := middleware.NewMiddleware(middleware.DefaultMiddlewares[byeParam]()...)
	lambda.Start(m.Apply(handler))
}

ミドルウェア関数の実体

func Log() LambdaMiddlewareFunc {
	return func(next LambdaHandlerFunc) LambdaHandlerFunc {
		return func(ctx context.Context, request events.APIGatewayProxyRequest) (response events.APIGatewayProxyResponse, returnErr error) {
    		// ミドルウェアの処理...

    		// 次の処理を実行
			return next(ctx, request)
		}
	}
}

ミドルウェアを作成する関数です。
ネストが深くてややこしく見え、下のような定義ではだめなのか?と思うかもしれません。

func Log(next LambdaHandlerFunc) LambdaHandlerFunc {
	return func(ctx context.Context, request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
	}
}

これは、例えば以下のように設定値を外から渡すことを考慮しています。

func Log(debug bool) LambdaMiddlewareFunc {
	return func(next LambdaHandlerFunc) LambdaHandlerFunc {
		return func(ctx context.Context, request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
			if !debug {
				return next(ctx, request)
			}
			res, err := next(ctx, request)
			fmt.Printf("req: %+v, res: %+v\n", request, res)
			return res, err
		}
	}
}

ハンドラー関数

func handler(ctx context.Context, request events.APIGatewayProxyRequest) (response events.APIGatewayProxyResponse, err error) {
	input := ctx.Value("ctxKeyParam").(byeParam)

	// ビジネスロジック
	result, err := doBye(ctx, input)
	if err != nil {
		return events.APIGatewayProxyResponse{StatusCode: 500, Body: err.Error()}, nil
	}

	// JSON返却
	return middleware.JsonResponse(200, result), nil
}

ParseInputでctxにセットしたバインディング済みのパラメータを、 ctx.Value("ctxKeyParam") で取得しています。
※なお、contextのkeyにstringを直接使用するのは非推奨なため、実際には独自型を定義すべきです。型アサーションのチェックもしたほうが良いでしょう。
ビジネスロジック以外の処理がミドルウェアに移ったことで、記述量が減り見通しがよくなっています。

まとめ

Lambdaのハンドラーに対して、net/httpやEchoといったWebフレームワークと同様にミドルウェアを適用する方法を紹介しました。
追加のライブラリなしで手軽にミドルウェアを実装することができました。

ミドルウェアの導入により、ハンドラーから非機能的な処理を分離することができ、開発者がビジネスロジックの実装に集中できるようになります。また全体的な変更の適用も容易になります。

今回はAPI Gateway向けのハンドラーの例を紹介しましたが、 型を変更すればSQSなどにも応用することができます。
複数のLambdaハンドラーの共通処理を書く際にこちらが参考になれば幸いです。