您当前的位置:首页 > IT编程 > python
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers | 情感分类 | 知识图谱 |

自学教程:pytorch中DataLoader()过程中遇到的一些问题

51自学网 2021-10-30 22:37:04
  python
这篇教程pytorch中DataLoader()过程中遇到的一些问题写得很实用,希望能帮到您。

如下所示:

RuntimeError: stack expects each tensor to be equal size, but got [3, 60, 32] at entry 0 and [3, 54, 32] at entry 2

train_dataset = datasets.ImageFolder(    traindir,    transforms.Compose([        transforms.Resize((224)) ###

原因是

transforms.Resize() 的参数设置问题,改为如下设置就可以了

train_dataset = datasets.ImageFolder(    traindir,    transforms.Compose([        transforms.Resize((224,224)),

同理,val_dataset中也调整为transforms.Resize((224,224))。

补充:pytorch之dataloader深入剖析

- dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问;

- 使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问;

- 也可以使用`for inputs, labels in dataloaders`进行可迭代对象的访问;

- 一般我们实现一个datasets对象,传入到dataloader中;然后内部使用yeild返回每一次batch的数据;

① DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存 ​

② Queue的特点

当队列里面没有数据时: queue.get() 会阻塞, 阻塞的时候,其它进程/线程如果有queue.put() 操作,本线程/进程会被通知,然后就可以 get 成功。

当数据满了: queue.put() 会阻塞

③ DataLoader是一个高效,简洁,直观的网络输入数据结构,便于使用和扩展

输入数据PipeLine

pytorch 的数据加载到模型的操作顺序是这样的:

① 创建一个 Dataset 对象

② 创建一个 DataLoader 对象

③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练

dataset = MyDataset()dataloader = DataLoader(dataset)num_epoches = 100for epoch in range(num_epoches):for img, label in dataloader:....

所以,作为直接对数据进入模型中的关键一步, DataLoader非常重要。

首先简单介绍一下DataLoader,它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口(除非用户重写…),该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

官方对DataLoader的说明是:“数据加载由数据集和采样器组成,基于python的单、多进程的iterators来处理数据。”关于iterator和iterable的区别和概念请自行查阅,在实现中的差别就是iterators有__iter__和__next__方法,而iterable只有__iter__方法。

1.DataLoader

先介绍一下DataLoader(object)的参数:

dataset(Dataset): 传入的数据集

batch_size(int, optional): 每个batch有多少个样本

shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序

sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数

pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…

如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each

worker subprocess with the worker id (an int in [0, num_workers - 1]) asinput, after seeding and before data loading. (default: None) 

- 首先dataloader初始化时得到datasets的采样list

class DataLoader(object):    r"""    Data loader. Combines a dataset and a sampler, and provides    single- or multi-process iterators over the dataset.    Arguments:        dataset (Dataset): dataset from which to load the data.        batch_size (int, optional): how many samples per batch to load            (default: 1).        shuffle (bool, optional): set to ``True`` to have the data reshuffled            at every epoch (default: False).        sampler (Sampler, optional): defines the strategy to draw samples from            the dataset. If specified, ``shuffle`` must be False.        batch_sampler (Sampler, optional): like sampler, but returns a batch of            indices at a time. Mutually exclusive with batch_size, shuffle,            sampler, and drop_last.        num_workers (int, optional): how many subprocesses to use for data            loading. 0 means that the data will be loaded in the main process.            (default: 0)        collate_fn (callable, optional): merges a list of samples to form a mini-batch.        pin_memory (bool, optional): If ``True``, the data loader will copy tensors            into CUDA pinned memory before returning them.        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,            if the dataset size is not divisible by the batch size. If ``False`` and            the size of dataset is not divisible by the batch size, then the last batch            will be smaller. (default: False)        timeout (numeric, optional): if positive, the timeout value for collecting a batch            from workers. Should always be non-negative. (default: 0)        worker_init_fn (callable, optional): If not None, this will be called on each            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as            input, after seeding and before data loading. (default: None)    .. note:: By default, each worker will have its PyTorch seed set to              ``base_seed + worker_id``, where ``base_seed`` is a long generated              by main process using its RNG. However, seeds for other libraies              may be duplicated upon initializing workers (w.g., NumPy), causing              each worker to return identical random numbers. (See              :ref:`dataloader-workers-random-seed` section in FAQ.) You may              use ``torch.initial_seed()`` to access the PyTorch seed for each              worker in :attr:`worker_init_fn`, and use it to set other seeds              before data loading.    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an                 unpicklable object, e.g., a lambda function.    """    __initialized = False    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,                 timeout=0, worker_init_fn=None):        self.dataset = dataset        self.batch_size = batch_size        self.num_workers = num_workers        self.collate_fn = collate_fn        self.pin_memory = pin_memory        self.drop_last = drop_last        self.timeout = timeout        self.worker_init_fn = worker_init_fn        if timeout < 0:            raise ValueError('timeout option should be non-negative')        if batch_sampler is not None:            if batch_size > 1 or shuffle or sampler is not None or drop_last:                raise ValueError('batch_sampler option is mutually exclusive '                                 'with batch_size, shuffle, sampler, and '                                 'drop_last')            self.batch_size = None            self.drop_last = None        if sampler is not None and shuffle:            raise ValueError('sampler option is mutually exclusive with '                             'shuffle')        if self.num_workers < 0:            raise ValueError('num_workers option cannot be negative; '                             'use num_workers=0 to disable multiprocessing.')        if batch_sampler is None:            if sampler is None:                if shuffle:                    sampler = RandomSampler(dataset)  //将list打乱                else:                    sampler = SequentialSampler(dataset)            batch_sampler = BatchSampler(sampler, batch_size, drop_last)        self.sampler = sampler        self.batch_sampler = batch_sampler        self.__initialized = True    def __setattr__(self, attr, val):        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):            raise ValueError('{} attribute should not be set after {} is '                             'initialized'.format(attr, self.__class__.__name__))        super(DataLoader, self).__setattr__(attr, val)    def __iter__(self):        return _DataLoaderIter(self)    def __len__(self):        return len(self.batch_sampler)

其中:RandomSampler,BatchSampler已经得到了采用batch数据的index索引;yield batch机制已经在!!!

class RandomSampler(Sampler):    r"""Samples elements randomly, without replacement.    Arguments:        data_source (Dataset): dataset to sample from    """    def __init__(self, data_source):        self.data_source = data_source    def __iter__(self):        return iter(torch.randperm(len(self.data_source)).tolist())    def __len__(self):        return len(self.data_source)
class BatchSampler(Sampler):    r"""Wraps another sampler to yield a mini-batch of indices.    Args:        sampler (Sampler): Base sampler.        batch_size (int): Size of mini-batch.        drop_last (bool): If ``True``, the sampler will drop the last batch if            its size would be less than ``batch_size``    Example:        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]    """    def __init__(self, sampler, batch_size, drop_last):        if not isinstance(sampler, Sampler):            raise ValueError("sampler should be an instance of "                             "torch.utils.data.Sampler, but got sampler={}"                             .format(sampler))        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or /                batch_size <= 0:            raise ValueError("batch_size should be a positive integeral value, "                             "but got batch_size={}".format(batch_size))        if not isinstance(drop_last, bool):            raise ValueError("drop_last should be a boolean value, but got "                             "drop_last={}".format(drop_last))        self.sampler = sampler        self.batch_size = batch_size        self.drop_last = drop_last    def __iter__(self):        batch = []        for idx in self.sampler:            batch.append(idx)            if len(batch) == self.batch_size:                yield batch                batch = []        if len(batch) > 0 and not self.drop_last:            yield batch    def __len__(self):        if self.drop_last:            return len(self.sampler) // self.batch_size        else:            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

- 其中 _DataLoaderIter(self)输入为一个dataloader对象;如果num_workers=0很好理解,num_workers!=0引入多线程机制,加速数据加载过程;

- 没有多线程时:batch = self.collate_fn([self.dataset[i] for i in indices])进行将index转化为data数据,返回(image,label);self.dataset[i]会调用datasets对象的

__getitem__()方法

- 多线程下,会为每个线程创建一个索引队列index_queues;共享一个worker_result_queue数据队列!在_worker_loop方法中加载数据;

class _DataLoaderIter(object):    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""    def __init__(self, loader):        self.dataset = loader.dataset        self.collate_fn = loader.collate_fn        self.batch_sampler = loader.batch_sampler        self.num_workers = loader.num_workers        self.pin_memory = loader.pin_memory and torch.cuda.is_available()        self.timeout = loader.timeout        self.done_event = threading.Event()        self.sample_iter = iter(self.batch_sampler)        base_seed = torch.LongTensor(1).random_().item()        if self.num_workers > 0:            self.worker_init_fn = loader.worker_init_fn            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]            self.worker_queue_idx = 0            self.worker_result_queue = multiprocessing.SimpleQueue()            self.batches_outstanding = 0            self.worker_pids_set = False            self.shutdown = False            self.send_idx = 0            self.rcvd_idx = 0            self.reorder_dict = {}            self.workers = [                multiprocessing.Process(                    target=_worker_loop,                    args=(self.dataset, self.index_queues[i],                          self.worker_result_queue, self.collate_fn, base_seed + i,                          self.worker_init_fn, i))                for i in range(self.num_workers)]            if self.pin_memory or self.timeout > 0:                self.data_queue = queue.Queue()                if self.pin_memory:                    maybe_device_id = torch.cuda.current_device()                else:                    # do not initialize cuda context if not necessary                    maybe_device_id = None                self.worker_manager_thread = threading.Thread(                    target=_worker_manager_loop,                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,                          maybe_device_id))                self.worker_manager_thread.daemon = True                self.worker_manager_thread.start()            else:                self.data_queue = self.worker_result_queue            for w in self.workers:                w.daemon = True  # ensure that the worker exits on process exit                w.start()            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))            _set_SIGCHLD_handler()            self.worker_pids_set = True            # prime the prefetch loop            for _ in range(2 * self.num_workers):                self._put_indices()    def __len__(self):        return len(self.batch_sampler)    def _get_batch(self):        if self.timeout > 0:            try:                return self.data_queue.get(timeout=self.timeout)            except queue.Empty:                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))        else:            return self.data_queue.get()    def __next__(self):        if self.num_workers == 0:  # same-process loading            indices = next(self.sample_iter)  # may raise StopIteration            batch = self.collate_fn([self.dataset[i] for i in indices])            if self.pin_memory:                batch = pin_memory_batch(batch)            return batch        # check if the next sample has already been generated        if self.rcvd_idx in self.reorder_dict:            batch = self.reorder_dict.pop(self.rcvd_idx)            return self._process_next_batch(batch)        if self.batches_outstanding == 0:            self._shutdown_workers()            raise StopIteration        while True:            assert (not self.shutdown and self.batches_outstanding > 0)            idx, batch = self._get_batch()            self.batches_outstanding -= 1            if idx != self.rcvd_idx:                # store out-of-order samples                self.reorder_dict[idx] = batch                continue            return self._process_next_batch(batch)    next = __next__  # Python 2 compatibility    def __iter__(self):        return self    def _put_indices(self):        assert self.batches_outstanding < 2 * self.num_workers        indices = next(self.sample_iter, None)        if indices is None:            return        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers        self.batches_outstanding += 1        self.send_idx += 1    def _process_next_batch(self, batch):        self.rcvd_idx += 1        self._put_indices()        if isinstance(batch, ExceptionWrapper):            raise batch.exc_type(batch.exc_msg)        return batch
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):    global _use_shared_memory    _use_shared_memory = True    # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal    # module's handlers are executed after Python returns from C low-level    # handlers, likely when the same fatal signal happened again already.    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1    _set_worker_signal_handlers()    torch.set_num_threads(1)    random.seed(seed)    torch.manual_seed(seed)    if init_fn is not None:        init_fn(worker_id)    watchdog = ManagerWatchdog()    while True:        try:            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)        except queue.Empty:            if watchdog.is_alive():                continue            else:                break        if r is None:            break        idx, batch_indices = r        try:            samples = collate_fn([dataset[i] for i in batch_indices])        except Exception:            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))        else:            data_queue.put((idx, samples))            del samples

- 需要对队列操作,缓存数据,使得加载提速!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持51zixue.net。


python 爬取影视网站下载链接
Django分页器的用法详解
万事OK自学网:51自学网_软件自学网_CAD自学网自学excel、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。