見出し画像

TWLogAIAN:異常ログを機械学習で検知する処理の速度改善が楽しくなってきた

今朝は4時半から開発開始です。昨日作っていた異常ログを機械学習で検知する処理が一通りできたので試してみました。処理の流れとしては、

  1. ログデータを読み込み、タイムスタンプの取得、特徴量ベクタの計算後にキーバリューストアにログと特徴量ベクタを保存する

  2. キーバリューストアから特徴量ベクタを読み出し間引きしてIsolation Forestのための学習データを作成する

  3. Isolation Forestの学習する

  4. キーバリューストアからログの特徴量ベクターを取得して異常スコアを計算する(全データ)

という感じです。貧弱なPCでもメモリ使用量を抑えて実行可能にするため計算したデータはキーバリューストアに保存しています。
最初に使ったキーバリューストアは、TWSNMP FCなどで使っているbbolt

です。最近使っている1000万件のアクセスログでトータル18分ぐらいでした。1の読み込みが6分半、2の学習データ作成が2分44秒、4のテストが4分半でした。
10分ぐらいにしたいという目標でモチベーションが上がってきました。
特徴量の計算を並列処理にしてみるなどを試してみましたが、悪化しました。そこで、最近見つけた高速なキーバリューストアBadger

https://github.com/dgraph-io/badger

を試してみることにしました。

の記事によるとbboltより10倍ぐらいデータの書き込みが速いらしいので期待して組み込んでみました。処理の書き方がbboltと似ているので比較的簡単に変更できました。
試してみると、

bboltとbadgerの性能比較

のような感じです。読み込みは3倍、トレーニングデータの作成は8倍に速度アップしています。しかし何故か異常スコアを計算するテストが極端に遅くなっています。この処理がbboltと同じぐらいにできれば、10分以内は達成できそうです。たぶん、キーバリューストア(Badger)から特徴量を読み出しとスコアの書き込みが並列で実行されるのが問題のような気がしています。
詳しく調べたいところですが、今朝は時間切れです。
明日に続く

bbolt版のテストプログラムは、

package main

import (
	"archive/zip"
	"bufio"
	"bytes"
	"encoding/binary"
	"fmt"
	"log"
	"math/rand"
	"strconv"
	"strings"
	"sync"
	"time"

	go_iforest "github.com/codegaudi/go-iforest"
	"github.com/gravwell/gravwell/v3/timegrinder"
	"go.etcd.io/bbolt"
)

var total = 0
var valid = 0
var inputData = [][]float64{}

var mu = &sync.Mutex{}

func main() {
	log.Println("start")
	if err := startDB(); err != nil {
		log.Fatalln(err)
	}
	if err := initTimegrinder(); err != nil {
		log.Fatalln(err)
	}
	loadData()
	for !vectorChDone {
		time.Sleep(time.Second)
	}
	makeInputData()
	makeIForest()
	checkData()
	for !scoreChDone {
		time.Sleep(time.Second)
	}
	closeDB()
	log.Println("end")
}

var tg *timegrinder.TimeGrinder

func initTimegrinder() error {
	var err error
	tg, err = timegrinder.New(timegrinder.Config{
		EnableLeftMostSeed: true,
	})
	if err != nil {
		return err
	}
	tg.SetLocalTime()
	return nil
}

type logEnt struct {
	Key int64
	Val *string
}

type vectorEnt struct {
	Key int64
	Val []float64
}

type scoreEnt struct {
	Key int64
	Val float64
}

var db *bbolt.DB
var logCh = make(chan *logEnt, 1000000)
var vectorCh = make(chan *vectorEnt, 100000)
var scoreCh = make(chan *scoreEnt, 100000)

func startDB() error {
	log.Println("start openDB")
	var err error
	db, err = bbolt.Open("log.db", 0600, nil)
	if err != nil {
		return err
	}
	buckets := []string{"logs", "vectors", "scores"}
	err = db.Update(func(tx *bbolt.Tx) error {
		for _, b := range buckets {
			_, err := tx.CreateBucketIfNotExists([]byte(b))
			if err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		return err
	}
	go logChProcess()
	go vectorChProcess()
	go scoreChProcess()
	return nil
}

func closeDB() {
	if db != nil {
		db.Close()
	}
}

var logChDone = false

func logChProcess() {
	logList := []*logEnt{}
	for e := range logCh {
		logList = append(logList, e)
		if len(logList) > 100000 {
			saveLogList(logList)
			logList = []*logEnt{}
		}
	}
	if len(logList) > 0 {
		saveLogList(logList)
	}
	logChDone = true
}

func saveLogList(list []*logEnt) {
	db.Batch(func(tx *bbolt.Tx) error {
		b := tx.Bucket([]byte("logs"))
		for _, e := range list {
			if err := b.Put([]byte(fmt.Sprintf("%016x", e.Key)), []byte(*e.Val)); err != nil {
				return err
			}
		}
		return nil
	})
}

var vectorChDone = false

func vectorChProcess() {
	vectorList := []*vectorEnt{}
	for e := range vectorCh {
		vectorList = append(vectorList, e)
		if len(vectorList) > 10000 {
			saveVectorList(vectorList)
			vectorList = []*vectorEnt{}
		}
	}
	if len(vectorList) > 0 {
		saveVectorList(vectorList)
	}
	vectorChDone = true
}

func saveVectorList(list []*vectorEnt) {
	db.Batch(func(tx *bbolt.Tx) error {
		b := tx.Bucket([]byte("vectors"))
		for _, e := range list {
			buf := new(bytes.Buffer)
			for _, v := range e.Val {
				err := binary.Write(buf, binary.LittleEndian, v)
				if err != nil {
					log.Println("binary.Write failed:", err)
					return err
				}
			}
			if err := b.Put([]byte(fmt.Sprintf("%016x", e.Key)), buf.Bytes()); err != nil {
				return err
			}
		}
		return nil
	})
}

var scoreChDone = false

func scoreChProcess() {
	scoreList := []*scoreEnt{}
	for e := range scoreCh {
		scoreList = append(scoreList, e)
		if len(scoreList) > 10000 {
			saveScoreList(scoreList)
			scoreList = []*scoreEnt{}
		}
	}
	if len(scoreList) > 0 {
		saveScoreList(scoreList)
	}
	scoreChDone = true
}

func saveScoreList(list []*scoreEnt) {
	db.Batch(func(tx *bbolt.Tx) error {
		b := tx.Bucket([]byte("scores"))
		for _, e := range list {
			buf := new(bytes.Buffer)
			err := binary.Write(buf, binary.LittleEndian, e.Val)
			if err != nil {
				fmt.Println("binary.Write failed:", err)
			}
			if err := b.Put([]byte(fmt.Sprintf("%016x", e.Key)), buf.Bytes()); err != nil {
				return err
			}
		}
		return nil
	})
}

func loadData() {
	log.Println("start loadData")
	st := time.Now()
	r, err := zip.OpenReader("../access.log.zip")
	if err != nil {
		log.Fatal(err)
	}
	defer r.Close()
	for _, f := range r.File {
		log.Printf("log file=%s", f.Name)
		file, err := f.Open()
		if err != nil {
			log.Fatal(err)
		}
		defer file.Close()
		scanner := bufio.NewScanner(file)
		for scanner.Scan() {
			l := scanner.Text()
			total++
			if total%1000000 == 0 {
				log.Printf("loadData total=%d valid=%d dur=%s", total, valid, time.Since(st))
			}
			lineProcess(&l)
		}
		if err := scanner.Err(); err != nil {
			log.Fatal(err)
		}
	}
	log.Printf("end loadData total=%d valid=%d dur=%s", total, valid, time.Since(st))
	close(logCh)
	close(vectorCh)
}

func makeInputData() {
	log.Println("start makeInputdata")
	st := time.Now()
	skip := total / 1000000
	if skip < 1 {
		skip = 1
	}
	db.View(func(tx *bbolt.Tx) error {
		b := tx.Bucket([]byte("vectors"))
		i := 0
		c := b.Cursor()
		for k, v := c.First(); k != nil; k, v = c.Next() {
			i++
			if i%1000000 == 0 {
				log.Printf("makeInputData i=%d len=%d", i, len(inputData))
			}
			if i%skip != 0 {
				continue
			}
			vector := []float64{}
			buf := bytes.NewReader(v)
			for {
				var f float64
				err := binary.Read(buf, binary.LittleEndian, &f)
				if err != nil {
					break
				}
				vector = append(vector, f)
			}
			inputData = append(inputData, vector)

		}
		return nil
	})
	log.Printf("end makeInputdata skip=%d input=%d dur=%s", skip, len(inputData), time.Since(st))

}

var iforest *go_iforest.IForest

func makeIForest() {
	log.Println("makeIForest start")
	st := time.Now()
	rand.Seed(time.Now().UnixNano())
	var err error
	iforest, err = go_iforest.NewIForest(inputData, 1000, 256)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("makeIForest end dur=%s", time.Since(st))
}

func checkData() {
	log.Println("checkData start")
	st := time.Now()
	db.View(func(tx *bbolt.Tx) error {
		b := tx.Bucket([]byte("vectors"))
		i := 0
		c := b.Cursor()
		for k, v := c.First(); k != nil; k, v = c.Next() {
			i++
			if i%1000000 == 0 {
				log.Printf("checkData i=%d", i)
			}
			vector := []float64{}
			buf := bytes.NewReader(v)
			for {
				var f float64
				err := binary.Read(buf, binary.LittleEndian, &f)
				if err != nil {
					break
				}
				vector = append(vector, f)
			}
			a := strings.SplitN(string(k), ":", 2)
			if len(a) != 2 {
				continue
			}
			key, err := strconv.ParseInt(a[1], 16, 64)
			if err != nil {
				continue
			}
			scoreCh <- &scoreEnt{
				Key: key,
				Val: iforest.CalculateAnomalyScore(vector),
			}
		}
		return nil
	})
	close(scoreCh)
	log.Printf("checkData end dur=%s", time.Since(st))
}

func lineProcess(s *string) {
	ts, ok, err := tg.Extract([]byte(*s))
	if err != nil || !ok {
		return
	}
	v := toVector(s)
	if len(v) > 1 {
		valid++
		k := ts.UnixNano() + int64((total % 100000))
		logCh <- &logEnt{
			Key: k,
			Val: s,
		}
		vectorCh <- &vectorEnt{
			Key: k,
			Val: v,
		}
	}
}

func toVector(s *string) []float64 {
	vector := []float64{}
	a := strings.Split(*s, "\"")
	if len(a) < 2 {
		return vector
	}
	f := strings.Fields(a[1])
	if len(f) < 3 {
		return vector
	}
	query := ""
	ua := strings.SplitN(f[1], "?", 2)
	path := ua[0]
	if len(ua) > 1 {
		query = ua[1]
	}
	ca := getCharCount(a[1])

	//findex_%
	vector = append(vector, float64(strings.Index(a[1], "%")))

	//findex_:
	vector = append(vector, float64(strings.Index(a[1], ":")))

	// countedCharArray
	for _, c := range []rune{':', '(', ';', '%', '/', '\'', '<', '?', '.', '#'} {
		vector = append(vector, float64(ca[c]))
	}

	//encoded =
	vector = append(vector, float64(strings.Count(a[1], "%3D")+strings.Count(a[1], "%3d")))

	//encoded /
	vector = append(vector, float64(strings.Count(a[1], "%2F")+strings.Count(a[1], "%2f")))

	//encoded \
	vector = append(vector, float64(strings.Count(a[1], "%5C")+strings.Count(a[1], "%5c")))

	//encoded %
	vector = append(vector, float64(strings.Count(a[1], "%25")))

	//%20
	vector = append(vector, float64(strings.Count(a[1], "%20")))

	//POST
	if strings.HasPrefix(a[1], "POST") {
		vector = append(vector, 1)
	} else {
		vector = append(vector, 0)
	}

	//path_nonalnum_count
	vector = append(vector, float64(len(path)-getAlphaNumCount(path)))

	//pvalue_nonalnum_avg
	vector = append(vector, float64(len(query)-getAlphaNumCount(query)))

	//non_alnum_len(max_len)
	vector = append(vector, float64(getMaxNonAlnumLength(a[1])))

	//non_alnum_count
	vector = append(vector, float64(getNonAlnumCount(a[1])))

	for _, p := range []string{"/%", "//", "/.", "..", "=/", "./", "/?"} {
		vector = append(vector, float64(strings.Count(a[1], p)))
	}
	return vector
}

func getCharCount(s string) []int {
	ret := []int{}
	for i := 0; i < 96; i++ {
		ret = append(ret, 0)
	}
	for _, c := range s {
		if 33 <= c && c <= 95 {
			ret[c] += 1
		}
	}
	return ret
}

func getAlphaNumCount(s string) int {
	ret := 0
	for _, c := range s {
		if 65 <= c && c <= 90 {
			ret++
		} else if 97 <= c && c <= 122 {
			ret++
		} else if 48 <= c && c <= 57 {
			ret++
		}
	}
	return ret
}

func getMaxNonAlnumLength(s string) int {
	max := 0
	length := 0
	for _, c := range s {
		if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9') {
			if length > max {
				max = length
			}
			length = 0
		} else {
			length++
		}
	}
	if max < length {
		max = length
	}
	return max
}

func getNonAlnumCount(s string) int {
	ret := 0
	for _, c := range s {
		if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9') {
		} else {
			ret++
		}
	}
	return ret
}

badger版は(共通の関数は省略しています)

package main

import (
	"archive/zip"
	"bufio"
	"bytes"
	"encoding/binary"
	"fmt"
	"log"
	"math/rand"
	"strconv"
	"strings"
	"sync"
	"time"

	go_iforest "github.com/codegaudi/go-iforest"
	badger "github.com/dgraph-io/badger/v3"
	"github.com/gravwell/gravwell/v3/timegrinder"
)

var total = 0
var valid = 0
var inputData = [][]float64{}

var mu = &sync.Mutex{}

func main() {
	log.Println("start")
	if err := startDB(); err != nil {
		log.Fatalln(err)
	}
	if err := initTimegrinder(); err != nil {
		log.Fatalln(err)
	}
	loadData()
	for !vectorChDone {
		time.Sleep(time.Second)
	}
	makeInputData()
	makeIForest()
	checkData()
	for !scoreChDone {
		time.Sleep(time.Second)
	}
	closeDB()
	log.Println("end")
}

var tg *timegrinder.TimeGrinder

func initTimegrinder() error {
	var err error
	tg, err = timegrinder.New(timegrinder.Config{
		EnableLeftMostSeed: true,
	})
	if err != nil {
		return err
	}
	tg.SetLocalTime()
	return nil
}

type logEnt struct {
	Key int64
	Val *string
}

type vectorEnt struct {
	Key int64
	Val []float64
}

type scoreEnt struct {
	Key int64
	Val float64
}

var db *badger.DB
var logCh = make(chan *logEnt, 100000)
var vectorCh = make(chan *vectorEnt, 100000)
var scoreCh = make(chan *scoreEnt, 100000)

func startDB() error {
	log.Println("start openDB")
	var err error
	db, err = badger.Open(badger.DefaultOptions("./badger"))
	if err != nil {
		return err
	}
	go logChProcess()
	go vectorChProcess()
	go scoreChProcess()
	return nil
}

func closeDB() {
	if db != nil {
		db.Close()
	}
}

var logChDone = false

func logChProcess() {
	logList := []*logEnt{}
	for e := range logCh {
		logList = append(logList, e)
		if len(logList) > 100000 {
			saveLogList(logList)
			logList = []*logEnt{}
		}
	}
	if len(logList) > 0 {
		saveLogList(logList)
	}
	logChDone = true
}

func saveLogList(list []*logEnt) {
	db.Update(func(txn *badger.Txn) error {
		for _, e := range list {
			if err := txn.Set([]byte(fmt.Sprintf("l:%016x", e.Key)), []byte(*e.Val)); err != nil {
				return err
			}
		}
		return nil
	})
}

var vectorChDone = false

func vectorChProcess() {
	vectorList := []*vectorEnt{}
	for e := range vectorCh {
		vectorList = append(vectorList, e)
		if len(vectorList) > 10000 {
			saveVectorList(vectorList)
			vectorList = []*vectorEnt{}
		}
	}
	if len(vectorList) > 0 {
		saveVectorList(vectorList)
	}
	vectorChDone = true
}

func saveVectorList(list []*vectorEnt) {
	db.Update(func(txn *badger.Txn) error {
		for _, e := range list {
			buf := new(bytes.Buffer)
			for _, v := range e.Val {
				err := binary.Write(buf, binary.LittleEndian, v)
				if err != nil {
					log.Println("binary.Write failed:", err)
					return err
				}
			}
			if err := txn.Set([]byte(fmt.Sprintf("v:%016x", e.Key)), buf.Bytes()); err != nil {
				return err
			}
		}
		return nil
	})
}

var scoreChDone = false

func scoreChProcess() {
	scoreList := []*scoreEnt{}
	for e := range scoreCh {
		scoreList = append(scoreList, e)
		if len(scoreList) > 10000 {
			saveScoreList(scoreList)
			scoreList = []*scoreEnt{}
		}
	}
	if len(scoreList) > 0 {
		saveScoreList(scoreList)
	}
	scoreChDone = true
}

func saveScoreList(list []*scoreEnt) {
	db.Update(func(txn *badger.Txn) error {
		for _, e := range list {
			buf := new(bytes.Buffer)
			err := binary.Write(buf, binary.LittleEndian, e.Val)
			if err != nil {
				fmt.Println("binary.Write failed:", err)
			}
			if err := txn.Set([]byte(fmt.Sprintf("s:%016x", e.Key)), buf.Bytes()); err != nil {
				return err
			}
		}
		return nil
	})
}

func loadData() {
	log.Println("start loadData")
	st := time.Now()
	r, err := zip.OpenReader("../access.log.zip")
	if err != nil {
		log.Fatal(err)
	}
	defer r.Close()
	for _, f := range r.File {
		log.Printf("log file=%s", f.Name)
		file, err := f.Open()
		if err != nil {
			log.Fatal(err)
		}
		defer file.Close()
		scanner := bufio.NewScanner(file)
		for scanner.Scan() {
			l := scanner.Text()
			total++
			if total%1000000 == 0 {
				log.Printf("loadData total=%d valid=%d dur=%s", total, valid, time.Since(st))
			}
			lineProcess(&l)
		}
		if err := scanner.Err(); err != nil {
			log.Fatal(err)
		}
	}
	log.Printf("end loadData total=%d valid=%d dur=%s", total, valid, time.Since(st))
	close(logCh)
	close(vectorCh)
}

func makeInputData() {
	log.Println("start makeInputdata")
	st := time.Now()
	skip := total / 1000000
	if skip < 1 {
		skip = 1
	}
	db.View(func(txn *badger.Txn) error {
		it := txn.NewIterator(badger.DefaultIteratorOptions)
		defer it.Close()
		i := 0
		prefix := []byte("v:")
		for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
			i++
			if i%1000000 == 0 {
				log.Printf("makeInputData i=%d len=%d", i, len(inputData))
			}
			if i%skip != 0 {
				continue
			}
			item := it.Item()
			err := item.Value(func(v []byte) error {
				vector := []float64{}
				buf := bytes.NewReader(v)
				for {
					var f float64
					err := binary.Read(buf, binary.LittleEndian, &f)
					if err != nil {
						break
					}
					vector = append(vector, f)
				}
				inputData = append(inputData, vector)
				return nil
			})
			if err != nil {
				return err
			}
		}
		return nil
	})
	log.Printf("end makeInputdata skip=%d input=%d dur=%s", skip, len(inputData), time.Since(st))
}

var iforest *go_iforest.IForest

func makeIForest() {
	log.Println("makeIForest start")
	st := time.Now()
	rand.Seed(time.Now().UnixNano())
	var err error
	iforest, err = go_iforest.NewIForest(inputData, 1000, 256)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("makeIForest end dur=%s", time.Since(st))
}

func checkData() {
	log.Println("checkData start")
	st := time.Now()
	db.View(func(txn *badger.Txn) error {
		it := txn.NewIterator(badger.DefaultIteratorOptions)
		defer it.Close()
		i := 0
		prefix := []byte("v:")
		for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
			i++
			if i%1000000 == 0 {
				log.Printf("checkData i=%d", i)
			}
			item := it.Item()
			k := item.Key()
			a := strings.SplitN(string(k), ":", 2)
			if len(a) != 2 {
				continue
			}
			key, err := strconv.ParseInt(a[1], 16, 64)
			if err != nil {
				continue
			}
			err = item.Value(func(v []byte) error {
				vector := []float64{}
				buf := bytes.NewReader(v)
				for {
					var f float64
					err := binary.Read(buf, binary.LittleEndian, &f)
					if err != nil {
						break
					}
					vector = append(vector, f)
				}
				scoreCh <- &scoreEnt{
					Key: key,
					Val: iforest.CalculateAnomalyScore(vector),
				}
				return nil
			})
			if err != nil {
				return err
			}
		}
		return nil
	})
	close(scoreCh)
	log.Printf("checkData end dur=%s", time.Since(st))
}

func lineProcess(s *string) {
	ts, ok, err := tg.Extract([]byte(*s))
	if err != nil || !ok {
		return
	}
	v := toVector(s)
	if len(v) > 1 {
		valid++
		k := ts.UnixNano() + int64((total % 100000))
		logCh <- &logEnt{
			Key: k,
			Val: s,
		}
		vectorCh <- &vectorEnt{
			Key: k,
			Val: v,
		}
	}
}



開発のための諸経費(機材、Appleの開発者、サーバー運用)に利用します。 ソフトウェアのマニュアルをnoteの記事で提供しています。 サポートによりnoteの運営にも貢献できるのでよろしくお願います。