Common Sampling Strategies in Machine Learning, with PyTorch Implementations

Simple Sampling Strategies

We first introduce three simple sampling strategies:

  1. Instance-balanced sampling
  2. Class-balanced sampling
  3. Square-root sampling

All three can be written in a single abstract form:

\[p_j=\frac{n_j^q}{\sum_{i=1}^C n_i^q},\]

where \(p_j\) is the probability of sampling a data point from class \(j\), \(C\) is the number of classes, \(n_j\) is the number of samples in class \(j\), and \(q\in\{1,0,\frac{1}{2}\}\).
Instance-balanced sampling
The most common way of sampling data: every training sample is selected with equal probability (\(q=1\)). The probability of sampling from class \(j\) is proportional to its sample count \(n_j\), i.e. \(p^{\mathbf{IB}}_j=\frac{n_j}{\sum_{i=1}^C n_i}\).

Class-balanced sampling
Instance-balanced sampling often performs poorly on imbalanced datasets; class-balanced sampling instead gives every class the same probability of being selected: \(p^{\mathbf{CB}}_j=\frac{1}{C}\). Sampling can be split into two stages: 1. uniformly pick a class from the set of classes; 2. uniformly pick an instance from that class.
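This two-stage view can also be coded directly. A minimal sketch, assuming targets is a 1-D NumPy array of class labels (the function name and RNG handling are my own choices):

import numpy as np

def two_stage_sample(targets, rng=None):
    rng = rng or np.random.default_rng()
    cls = rng.choice(np.unique(targets))               # stage 1: pick a class uniformly
    return rng.choice(np.flatnonzero(targets == cls))  # stage 2: pick an instance of that class uniformly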
Square-root sampling
The most commonly used variant of the strategies above, with \(q=\frac{1}{2}\).
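For a concrete feel, take three hypothetical classes with counts \(n=(100,10,1)\) (illustrative numbers, not from the experiments below). Instance-balanced sampling (\(q=1\)) yields \(p\approx(0.90,0.09,0.01)\), class-balanced sampling (\(q=0\)) yields \(p=(\frac{1}{3},\frac{1}{3},\frac{1}{3})\), and square-root sampling (\(q=\frac{1}{2}\)) yields \(p\approx(0.71,0.22,0.07)\), which sits between the two.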

Since all three strategies simply adjust the per-class sampling probabilities (weights), they can all be implemented with PyTorch's WeightedRandomSampler:

import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler

def get_sampler(sampling_type, targets):
    # targets is a 1-D array of class labels; cls_counts has one entry per class.
    cls_counts = np.bincount(targets)
    if sampling_type == 'instance-balanced':
        cls_weights = cls_counts / np.sum(cls_counts)
    elif sampling_type == 'class-balanced':
        cls_num = len(cls_counts)
        cls_weights = [1. / cls_num] * cls_num
    elif sampling_type == 'square-root':
        sqrt_and_sum = np.sum([num**0.5 for num in cls_counts])
        cls_weights = [num**0.5 / sqrt_and_sum for num in cls_counts]
    else:
        raise ValueError('sampling_type should be instance-balanced, class-balanced or square-root')
    cls_weights = np.array(cls_weights)
    # Broadcast the class weights to per-sample weights via fancy indexing.
    return WeightedRandomSampler(cls_weights[targets], len(targets), replacement=True)

WeightedRandomSampler takes the per-sample weights as its first argument, the number of samples to draw as its second, and whether to sample with replacement as its third.
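As a minimal standalone illustration of these semantics (the weights below are arbitrary):

from torch.utils.data.sampler import WeightedRandomSampler

# Draw 4 indices from a population of 3 samples, with replacement;
# index 1 carries most of the probability mass (weights need not sum to 1).
sampler = WeightedRandomSampler(weights=[0.1, 0.8, 0.1], num_samples=4, replacement=True)
print(list(sampler))  # e.g. [1, 1, 0, 1]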

Testing on a simulated long-tailed dataset:

import torch
from torch.utils.data import Dataset, DataLoader, Sampler

torch.manual_seed(0)
np.random.seed(0)

class LongTailDataset(Dataset):
    def __init__(self, num_classes, max_samples_per_class):
        self.num_classes = num_classes
        self.max_samples_per_class = max_samples_per_class

        # Generate number of samples for each class inversely proportional to class index
        self.samples_per_class = [self.max_samples_per_class // (i + 1) for i in range(self.num_classes)]
        self.total_samples = sum(self.samples_per_class)

        # Generate targets for the dataset
        self.targets = torch.cat([torch.full((samples,), i, dtype=torch.long)
                                  for i, samples in enumerate(self.samples_per_class)])

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        # For simplicity, just return the index as the data
        return idx, self.targets[idx]

# Parameters
num_classes = 25
max_samples_per_class = 1000

# Create dataset
dataset = LongTailDataset(num_classes, max_samples_per_class)

# Create samplers
sampler1 = get_sampler('instance-balanced', dataset.targets.numpy())
sampler2 = get_sampler('class-balanced', dataset.targets.numpy())
sampler3 = get_sampler('square-root', dataset.targets.numpy())

def test_sampler_in_one_batch(sampler: Sampler, inf: str):
    print(inf)
    for (_, target) in DataLoader(dataset, batch_size=64, sampler=sampler):
        cls_idx, cls_counts = np.unique(target.numpy(), return_counts=True)
        print(f'Class indices: {cls_idx}')
        print(f'Class counts: {cls_counts}')
        break  # just show one batch
    print('-' * 20)

samplers = [sampler1, sampler2, sampler3]
infs = ['Instance-balanced:', 'Class-balanced:', 'Square-root:']
for sampler, inf in zip(samplers, infs):
    test_sampler_in_one_batch(sampler, inf)

Output:

Instance-balanced:
Class indices: [ 0  1  2  3  5 16 22 23]
Class counts: [42 10  5  2  2  1  1  1]
--------------------
Class-balanced:
Class indices: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 16 17 20 21 23]
Class counts: [22  7  6  4  2  1  2  2  3  3  1  2  1  1  1  1  2  1  1  1]
--------------------
Square-root:
Class indices: [ 0  1  2  3  4  5  6  9 10 21 22 23]
Class counts: [37  8  3  6  3  1  1  1  1  1  1  1]
--------------------

Mixed Sampling Strategies

The earliest mixed strategy applies instance-balanced sampling for the first \(t\) epochs and class-balanced sampling for the remaining \(T-t\) epochs, which requires choosing a suitable hyperparameter \(t\). In [1], the authors propose a soft version of this scheme, progressively-balanced sampling, in which each class's sampling probability (weight) \(p_j\) changes as training progresses:

\[p_j^{\mathbf{PB}}(t)=\left(1-\frac{t}{T}\right)p_j^{\mathbf{IB}}+\frac{t}{T}p_j^{\mathbf{CB}}\]

where \(t\) is the current epoch and \(T\) is the total number of epochs.
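Reusing the conventions of get_sampler above, a minimal sketch of a progressively-balanced sampler (the function name is hypothetical, and rebuilding the sampler every epoch is my own choice, not something prescribed by [1]):

import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler

def get_progressive_sampler(targets, t, T):
    cls_counts = np.bincount(targets)
    p_ib = cls_counts / np.sum(cls_counts)                 # instance-balanced weights
    p_cb = np.full(len(cls_counts), 1. / len(cls_counts))  # class-balanced weights
    p_pb = (1 - t / T) * p_ib + (t / T) * p_cb             # interpolate as t grows
    return WeightedRandomSampler(p_pb[targets], len(targets), replacement=True)

Because the weights depend on \(t\), the sampler (and the DataLoader built on top of it) has to be re-created at the start of every epoch.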

Sampling Strategies for Imbalanced Datasets

On imbalanced datasets, and long-tailed ones in particular, a common way of favoring the tail classes is to set each class's sampling probability (weight) to the inverse of its sample count, i.e. \(p_j=\frac{1}{n_j}\):

...
elif sampling_type == 'inverse':
    cls_weights = 1. / cls_counts
...

In [3] the concept of the effective number of samples is introduced: the denominator is no longer the raw sample count but a quantity derived from it. We give the result directly here; see the original paper for the proof. The effective number is computed as:

\[E_n=(1-\beta^n)/(1-\beta),\ \mathrm{where}\ \beta=(N-1)/N.\]

Here \(N\) is the total number of samples in the dataset.
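For a concrete feel (the numbers here are illustrative, not from the experiments above): with \(N=1000\) we get \(\beta=0.999\); a head class with \(n=1000\) has \(E_n=(1-0.999^{1000})/0.001\approx 632\), while a tail class with \(n=10\) has \(E_n\approx 9.96\approx n\). The effective number saturates for head classes but stays close to the raw count for tail classes, so weighting by \(\frac{1}{E_n}\) is a softened version of inverse weighting.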

The corresponding code:

...
elif sampling_type == 'effective':
    beta = (len(targets) - 1) / len(targets)
    cls_weights = (1.0 - beta) / (1.0 - np.power(beta, cls_counts))
...

Output:

Effective:
Class indices: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 16 17 18 20 21 22 23 24]
Class counts: [2 1 2 3 1 1 4 2 3 4 4 2 3 5 2 4 1 3 1 4 5 6 1]
--------------------

On the same simulated long-tailed dataset as above, the sampled batch is clearly more balanced.

References

  1. Kang, Bingyi, et al. "Decoupling Representation and Classifier for Long-Tailed Recognition." International Conference on Learning Representations. 2020.
  2. torch.utils.data.WeightedRandomSampler, PyTorch documentation.
  3. Cui, Yin, et al. "Class-Balanced Loss Based on Effective Number of Samples." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.