1. 对列表和数组进行切片
1.1 切片索引
众所周知,Python中的列表和numpy数组都支持用begin: end
语法来表示[begin, end)
区间的的切片索引:
import numpy as np my_list= [1, 2, 3, 4, 5] print(my_list[2: 4]) # [3, 4] my_arr = np.array([1, 2, 3, 4, 5]) print(my_arr[2: 4]) # [3 4]
以上操作实际上等同于用slice
切片索引对象对其进行切片:
print(my_list[slice(2, 4)]) # [3, 4] print(my_arr[slice(2, 4)]) # [3 4]
numpy数组还支持用列表和numpy数组来表示切片索引,而列表则不支持:
print(my_arr[[2, 3]]) # [3 4] print(my_arr[np.arange(2, 4)]) # [3, 4] print(my_list[[2, 3]]) # TypeError: list indices must be integers or slices, not list print(my_list[np.arange(2, 4)]) # TypeError: only integer scalar arrays can be converted to a scalar index
Pytorch的torch.utils.data.Dataset
数据集支持单元素索引,但不支持切片:
from torchvision.datasets import FashionMNIST from torchvision.transforms import Compose, ToTensor, Normalize transform = Compose( [ToTensor(), Normalize((0.1307,), (0.3081,)) ] ) data = FashionMNIST( root="data", download=True, train=True, transform=transform ) print(data[0], data[1]) # (tensor(...), 0) (tensor(...), 0) print(data[[0, 1]]) # ValueError: only one element tensors can be converted to Python scalars print(data[: 2]) # ValueError: only one element tensors can be converted to Python scalars
要想对torch.utils.data.Dataset
进行切片,需要创建Subset
对象:
import torch indices = [0, 1] # or indices = np.arange(2) data_0to1 = torch.utils.data.Subset(data, indices) print(type(data_0to1)) # <class 'torch.utils.data.dataset.Subset'>
Subset
对象同样支持单元素索引操作且不支持切片:
print(data_0to1[0]) # (tensor(...), 0)
查看Pytorch源码可知,Subset
类的定义实际上是这样的:
class Subset(Dataset[T_co]): r""" Subset of a dataset at specified indices. Args: dataset (Dataset): The whole Dataset indices (sequence): Indices in the whole set selected for subset """ dataset: Dataset[T_co] indices: Sequence[int] def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None: self.dataset = dataset self.indices = indices def __getitem__(self, idx): return self.dataset[self.indices[idx]] def __len__(self): return len(self.indices)
从以上代码片段可以清晰地看到Subset
类用indices
来存储本身做为子集的索引集合,然后重写(override)了__getitem__()
方法来实现对子集的单元素索引。
1.2 对切片索引进行命名
有时我们会使用充满硬编码的切片索引,这使得代码难以阅读,比如下面这段代码:
record = ".....100...513.25.." cost = int(record[5: 8]) * float(record[11: 17]) print(cost) # 51325.0
与其这样做,我们不如对切片进行命名:
SHARES = slice(5, 8) PRICE = slice(11, 17) cost = int(record[SHARES]) * float(record[PRICE]) print(cost) # 51325.0
在后一种版本中,由于避免了使用许多神秘难懂的硬编码索引,我们的代码就变得清晰了许多。
正如我们前面所说,这里的slice()
函数会创建一个slice
类型的切片对象,可以用在任何运行切片的地方:
items = [0, 1, 2, 3, 4, 5, 6] a = slice(2, 4) print(items[2: 4]) # [2, 3] print(items[a]) # [2, 3] items[a] = [10, 11] print(items) # [0, 1, 10, 11, 4, 5, 6] del items[a] print(items) # [0, 1, 4, 5, 6]
如果有一个slice
对象的实例s
,可以分别用过s.start
、s.stop
以及s.step
属性来跌倒关于该对象的信息。例如:
a = slice(5, 50, 2) print(a.start, a.stop, a.step) # 5 10 2
此外,可以通过使用indices(size)
方法将切片映射到特定大小的序列上。这会返回一个[start, stop, step)
元组,所有的值都已经恰当地限制在边界以内(当做索引操作时可避免出现IndexError
异常)。例如:
s = 'HelloWorld' print(a.indices(len(s))) print(*a.indices(len(s))) for i in range(*a.indices(len(s))): print(s[i]) # W # r # d
2. 对迭代器做切片操作
要对迭代器和生成器做切片操作,普通的切片操作符在这里是不管用的:
def count(n): while True: yield n n += 1 c = count(0) print(c[10: 20]) # TypeError: 'generator' object is not subscriptable
此时,itertools.islice()
函数是最完美的选择:
import itertools for x in itertools.islice(c, 10, 20): print(x) # 10 # 11 # 12 # 13 # 14 # 15 # 16 # 17 # 18 # 19
注意,迭代器和生成器之所以没法执行普通的切片操作,这是因为不知道它们的长度是多少(而且它们也没有实现索引)。islice()
产生的结果是一个迭代器,它可以产生出所需要的切片元素,但这是通过访问并丢弃起始索引之前的元素来实现的。之后的元素会由islice
对象产生出来,直到到达结束索引为止。
还有一点需要重点强调的是islice()
会消耗掉所提供的的迭代器中数据。由于迭代器中的元素只能访问一次,没法倒回去,因此这里就需要引起我们的注意了。如果之后还需要倒回去访问前面的元素,那也许就应该先将数据转到列表中去。
参考
- [1] https://stackoverflow.com/questions/54251798/pytorch-can-not-slice-torchvision-mnist-dataset
- [2] Martelli A, Ravenscroft A, Ascher D. Python cookbook[M]. " O'Reilly Media, Inc.", 2015.