
Fast AutoAugment

- arXiv
- Official PyTorch implementation
- One of the authors (Kaggle Competitions Master)


1. Abstract


1-1.   Fast AutoAugment is an algorithm to automatically search for augmentation policies.

1-2.   The GPU workload of Fast AutoAugment is small, and its search strategy is based on density matching.



2. Introduction and Related Work

2-1.   Baseline Augmentation, Mixup, Cutout, and CutMix
- Designed manually based on domain knowledge.

2-2.   Smart Augmentation

- Learns to generate augmented data by merging two or more samples in the same class.

2-3.   Simulated GAN 


2-4.   AutoAugment
- Uses reinforcement learning (RL).
- Requires thousands of GPU hours.


2-5.   Population Based Augmentation (PBA)
- Based on the population-based training method for hyperparameter optimization.

2-6.   Bayesian Data Augmentation (BDA)
- New annotated training points are treated as missing variables and generated based on the distribution learned from the training set.

- Uses a generalized Monte Carlo EM algorithm, formulated as an extension of the Generative Adversarial Network (GAN).


2-7.   Fast AutoAugment
- Directly searches for augmentation policies that maximize the match between the distribution of the augmented split and the distribution of another, un-augmented split via a single model (i.e., density matching).

- Unlike BDA, recovers those missing data points through exploitation and exploration via Bayesian optimization in the policy-search phase.

- As Table 1 below shows, Fast AutoAugment searches augmentation policies significantly faster than AutoAugment, while retaining comparable performance on diverse image datasets and networks.

[Table 1]


- Uses the Tree-structured Parzen Estimator (TPE) algorithm for its practical implementation.



3. Fast AutoAugment


3-1.     Search Space

[Figure 1]

- The search space S consists of sub-policies (tau_1, tau_2, ...), and each operation in a sub-policy has two parameters: p (the probability of applying the operation) and lambda (the magnitude of the operation).

- A sub-policy can consist of any number of consecutive operations; in the case of Figure 1, two operations are applied.

- The final policy is a collection of sub-policies (a minimal sketch of this structure is shown below).
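Below is a minimal, hypothetical sketch of this structure in Python (not the official implementation): a sub-policy is a sequence of (operation, p, lambda) tuples, and one sub-policy is sampled per image at training time. The names apply_sub_policy, apply_policy, and the placeholder ops dictionary are assumptions for illustration.

```python
import random

def apply_sub_policy(image, sub_policy, ops):
    # Each operation is applied with probability p and magnitude lam.
    for name, p, lam in sub_policy:
        if random.random() < p:
            image = ops[name](image, lam)   # `ops` maps operation names to transforms
    return image

def apply_policy(image, policy, ops):
    # The final policy is a collection of sub-policies; sample one per image.
    return apply_sub_policy(image, random.choice(policy), ops)

# Illustrative placeholders; real transforms would use PIL / torchvision.
ops = {"ShearX": lambda img, lam: img, "Invert": lambda img, lam: img}
policy = [[("ShearX", 0.6, 0.3), ("Invert", 0.2, 0.0)]]
```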


3-2.     Search Strategy

- Split the training data into D_M (for the Model) and D_A (for Augmentation), which are used for learning the model parameters (theta) and exploring the augmentation policy (tau), respectively.

[Equation (2)]

- The augmentation policy is obtained by optimizing equation (2), where R is the accuracy, on the augmented D_A, of the model with parameters (theta^*) trained on D_M.
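The following is a reconstruction of the objective in equation (2) based on the description above (not a verbatim copy of the paper's typesetting): theta^* is trained on D_M without augmentation, and the policy is chosen to maximize the accuracy R on the augmented split.

```latex
% theta^* : model parameters trained on D_M (without augmentation)
% T(D_A)  : the split D_A augmented by policy T
\mathcal{T}_{*} = \operatorname*{argmax}_{\mathcal{T}} \; R\big(\theta^{*} \mid \mathcal{T}(D_A)\big)
```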


[Figure 2]

- Figure 2 shows how the optimization in equation (2) is performed.

- The procedure is as follows:

[Algorithm 1]


<Step 1> 
Perform K-fold stratified shuffling to split the training dataset into D^(1)_train, ..., D^(K)_train, where each D^(k)_train consists of two splits, D^(k)_M and D^(k)_A. This uses the StratifiedShuffleSplit method in sklearn (see the sketch after these steps).
<Step 2>
Train the model parameters theta on D^(k)_M from scratch, without data augmentation.
<Step 3>
At each step (t = 0, 1, ..., T-1), explore B candidate policies (tau_1, tau_2, ..., tau_B) via the Bayesian optimization method, then select the top-N policies over the B candidates (= tau_t in line 8 of Algorithm 1 above).
<Step 4>
Merge every tau_t into the final policy set (= tau_* in line 9).
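As referenced in Step 1, here is a condensed, hypothetical sketch of Steps 1-4 in Python. The data X, y, the number of folds K, and the 0.5 split ratio between D_M and D_A are illustrative assumptions; training and policy search are left as comments since they depend on the model and the search backend.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Illustrative data and settings (assumptions, not the paper's values).
X = np.random.rand(1000, 32 * 32 * 3)
y = np.random.randint(0, 10, size=1000)
K = 5

# Step 1: K stratified shuffles, each yielding D_M (model) and D_A (augmentation).
sss = StratifiedShuffleSplit(n_splits=K, test_size=0.5, random_state=0)
final_policy = []                                  # tau_* (Step 4, line 9)
for m_idx, a_idx in sss.split(X, y):
    D_M = (X[m_idx], y[m_idx])                     # used to train theta (Step 2)
    D_A = (X[a_idx], y[a_idx])                     # used to score candidate policies (Step 3)
    # Step 2: train theta on D_M from scratch, without augmentation.
    # Step 3: for t = 0, ..., T-1, explore B candidate policies with Bayesian
    #         optimization and keep the top-N as tau_t.
    # Step 4: merge every tau_t into final_policy.
```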

- Policy exploration is performed with HyperOpt via Ray, using TPE to estimate the Expected Improvement.
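As a rough, standalone illustration of this exploration step (the official code wires HyperOpt through Ray, which is omitted here), the sketch below uses the hyperopt library's TPE sampler to search over a single two-operation sub-policy. The operation list, the evaluate_policy stub, and the budget of 200 evaluations are assumptions for illustration.

```python
import random
from hyperopt import fmin, hp, tpe, Trials

OPS = ["ShearX", "ShearY", "Rotate", "Invert", "Cutout"]   # illustrative subset of operations

def evaluate_policy(sub_policy):
    # Stand-in for the real evaluation: accuracy of the trained model theta^*
    # on D_A augmented by `sub_policy`. Replace with a real evaluation loop.
    return random.random()

# Search space for one two-operation sub-policy: operation choice,
# application probability p, and magnitude lam for each slot.
space = [
    {"op": hp.choice(f"op_{i}", OPS),
     "p": hp.uniform(f"p_{i}", 0.0, 1.0),
     "lam": hp.uniform(f"lam_{i}", 0.0, 1.0)}
    for i in range(2)
]

def objective(sub_policy):
    return -evaluate_policy(sub_policy)    # fmin minimizes, so negate the accuracy

best = fmin(fn=objective, space=space, algo=tpe.suggest,
            max_evals=200, trials=Trials())
```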

