背景介绍

OpenMMLab项目中构建数据集的基础类BaseDataset类的时候,对多进程数据加载有一个优化,叫做 ‘‘’序列化’。

先看代码部分

class BaseDataset(Dataset):

r"""BaseDataset for open source projects in OpenMMLab.

The annotation format is shown as follows.

.. code-block:: none

{

"metainfo":

{

"dataset_type": "test_dataset",

"task_name": "test_task"

},

"data_list":

[

{

"img_path": "test_img.jpg",

"height": 604,

"width": 640,

"instances":

[

{

"bbox": [0, 0, 10, 20],

"bbox_label": 1,

"mask": [[0,0],[0,10],[10,20],[20,0]],

"extra_anns": [1,2,3]

},

{

"bbox": [10, 10, 110, 120],

"bbox_label": 2,

"mask": [[10,10],[10,110],[110,120],[120,10]],

"extra_anns": [4,5,6]

}

]

},

]

}

Args:

ann_file (str, optional): Annotation file path. Defaults to ''.

metainfo (Mapping or Config, optional): Meta information for

dataset, such as class information. Defaults to None.

data_root (str, optional): The root directory for ``data_prefix`` and

``ann_file``. Defaults to ''.

data_prefix (dict): Prefix for training data. Defaults to

dict(img_path='').

filter_cfg (dict, optional): Config for filter data. Defaults to None.

indices (int or Sequence[int], optional): Support using first few

data in annotation file to facilitate training/testing on a smaller

serialize_data (bool, optional): Whether to hold memory using

serialized objects, when enabled, data loader workers can use

shared RAM from master process instead of making a copy. Defaults

to True.

pipeline (list, optional): Processing pipeline. Defaults to [].

test_mode (bool, optional): ``test_mode=True`` means in test phase.

Defaults to False.

lazy_init (bool, optional): Whether to load annotation during

instantiation. In some cases, such as visualization, only the meta

information of the dataset is needed, which is not necessary to

load annotation file. ``Basedataset`` can skip load annotations to

save time by set ``lazy_init=True``. Defaults to False.

max_refetch (int, optional): If ``Basedataset.prepare_data`` get a

None img. The maximum extra number of cycles to get a valid

image. Defaults to 1000.

Note:

BaseDataset collects meta information from ``annotation file`` (the

lowest priority), ``BaseDataset.METAINFO``(medium) and ``metainfo

parameter`` (highest) passed to constructors. The lower priority meta

information will be overwritten by higher one.

Note:

Dataset wrapper such as ``ConcatDataset``, ``RepeatDataset`` .etc.

should not inherit from ``BaseDataset`` since ``get_subset`` and

``get_subset_`` could produce ambiguous meaning sub-dataset which

conflicts with original dataset.

Examples:

>>> # Assume the annotation file is given above.

>>> class CustomDataset(BaseDataset):

>>> METAINFO: dict = dict(task_name='custom_task',

>>> dataset_type='custom_type')

>>> metainfo=dict(task_name='custom_task_name')

>>> custom_dataset = CustomDataset(

>>> 'path/to/ann_file',

>>> metainfo=metainfo)

>>> # meta information of annotation file will be overwritten by

>>> # `CustomDataset.METAINFO`. The merged meta information will

>>> # further be overwritten by argument `metainfo`.

>>> custom_dataset.metainfo

{'task_name': custom_task_name, dataset_type: custom_type}

"""

METAINFO: dict = dict()

_fully_initialized: bool = False

def __init__(self,

ann_file: Optional[str] = '',

metainfo: Union[Mapping, Config, None] = None,

data_root: Optional[str] = '',

data_prefix: dict = dict(img_path=''),

filter_cfg: Optional[dict] = None,

indices: Optional[Union[int, Sequence[int]]] = None,

serialize_data: bool = True,

pipeline: List[Union[dict, Callable]] = [],

test_mode: bool = False,

lazy_init: bool = False,

max_refetch: int = 1000):

self.ann_file = ann_file

self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

self.data_root = data_root

self.data_prefix = copy.copy(data_prefix)

self.filter_cfg = copy.deepcopy(filter_cfg)

self._indices = indices

self.serialize_data = serialize_data

self.test_mode = test_mode

self.max_refetch = max_refetch

self.data_list: List[dict] = []

self.data_bytes: np.ndarray

# Join paths.

self._join_prefix()

# Build pipeline.

self.pipeline = Compose(pipeline)

# Full initialize the dataset.

if not lazy_init:

self.full_init()

@force_full_init

def get_data_info(self, idx: int) -> dict:

"""Get annotation by index and automatically call ``full_init`` if the

dataset has not been fully initialized.

序列化方式通过内存映射和反序列化,可能更适合处理大规模数据或减少内存占用,

而非序列化方式则更简单直接,适用于数据规模较小或内存资源充足的情况。

Args:

idx (int): The index of data.

Returns:

dict: The idx-th annotation of the dataset.

无论哪种方式,最后得到的 data_info 变量都包含了索引 idx 对应的数据。

- 序列化数据加载时,通过地址计算、内存视图和反序列化,从字节数组中提取数据;

- 非序列化数据加载时,直接从已存储的对象列表中复制所需数据。

两种方式适应了不同的存储场景和性能需求。

"""

if self.serialize_data:

start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()

end_addr = self.data_address[idx].item()

bytes = memoryview(

self.data_bytes[start_addr:end_addr]) # type: ignore

data_info = pickle.loads(bytes) # type: ignore

else:

data_info = copy.deepcopy(self.data_list[idx])

# Some codebase needs `sample_idx` of data information. Here we convert

# the idx to a positive number and save it in data information.

if idx >= 0:

data_info['sample_idx'] = idx

else:

data_info['sample_idx'] = len(self) + idx

return data_info

def full_init(self):

"""Load annotation file and set ``BaseDataset._fully_initialized`` to

True.

If ``lazy_init=False``, ``full_init`` will be called during the

instantiation and ``self._fully_initialized`` will be set to True. If

``obj._fully_initialized=False``, the class method decorated by

``force_full_init`` will call ``full_init`` automatically.

Several steps to initialize annotation:

- load_data_list: Load annotations from annotation file.

- filter data information: Filter annotations according to

filter_cfg.

- slice_data: Slice dataset according to ``self._indices``

- serialize_data: Serialize ``self.data_list`` if

``self.serialize_data`` is True.

"""

# check是不是 self._fully_initialized 和 self.serialize_data 不能同时为 true

if self._fully_initialized:

return

# load data information

self.data_list = self.load_data_list()

# filter illegal data, such as data that has no annotations.

self.data_list = self.filter_data()

# Get subset data according to indices.

if self._indices is not None:

self.data_list = self._get_unserialized_subset(self._indices)

# serialize data_list

if self.serialize_data:

self.data_bytes, self.data_address = self._serialize_data()

self._fully_initialized = True

BaseDataset类中定义了一些可能会影响内存使用的方法和属性,例如:

data_list:存储数据集所有样本的列表,每个样本都是一个字典,包含了图像路径、尺寸和实例信息等。serialize_data:一个布尔值,指示是否在初始化时将data_list序列化以节省内存。当启用时,数据加载器的工作进程可以使用主进程的共享RAM,而不是进行复制。_serialize_data和_get_serialized_subset:这些方法用于序列化和获取序列化数据的子集,这有助于在多进程数据加载时减少内存消耗。

在分布式训练中,如果每个GPU rank都加载完整的data_list,那么确实会导致内存的重复使用。为了解决这个问题,serialize_data属性被设置为True时,可以通过序列化数据来节省内存,这样每个工作进程就可以共享主进程的RAM,而不是各自复制一份数据。

serialize_data

在多进程数据加载的场景下,比如使用PyTorch的DataLoader时,每个工作进程(worker)通常需要加载数据集的一部分来并行处理。如果没有序列化处理,每个工作进程都会复制一份完整的data_list到自己的内存空间中,这会导致内存的大量重复使用,特别是在数据集很大的情况下。

通过serialize_data参数启用序列化后,数据集的样本信息会被转换成一个二进制格式的数组(data_bytes),并且每个样本信息的起始和结束位置会被记录在一个地址数组(data_address)中。这样,当数据加载器的工作进程需要获取数据时,它们可以直接从共享的data_bytes数组中按地址提取所需的样本信息,而无需复制整个数据列表。这意味着所有的工作进程都可以直接使用主进程中的共享内存,从而大大减少了内存的使用。

进一步理解 serialize data

用一个餐厅的比喻来理解serialize_data的概念。

你经营一家非常受欢迎的餐厅,这家餐厅的菜单上有100道菜。每天,你都需要为顾客提供这些菜,但是每道菜的需求量是不同的。为了高效地为顾客服务,你有两种选择:

不序列化(serialize_data=False): 这就像你在餐厅里为每个服务员准备一份完整的菜单,每份菜单上都有100道菜。每天早上,服务员们会从厨房领取他们需要的所有食材,准备一天的工作。这意味着每个服务员都需要携带大量的食材,而且厨房也需要准备足够的食材来满足所有服务员的需求。这在餐厅规模较小、顾客较少时是可行的,但如果餐厅很大,或者顾客非常多,这就会导致厨房的食材库存压力巨大,效率低下。 序列化(serialize_data=True): 现在,你决定改变策略。厨房不再为每个服务员准备一份完整的菜单,而是将每道菜的食材打包成单独的小包裹,并在每个包裹上贴上标签,说明这是哪道菜的食材。服务员们只需要根据顾客的订单来领取对应的食材包裹。这样,厨房只需要准备足够的食材来满足所有顾客的总需求,而不是每个服务员的需求。服务员们也不需要携带大量的食材,他们只需要根据需要领取相应的包裹即可。这种方式大大减少了食材的浪费和厨房的存储压力,提高了服务效率。

在数据集处理的上下文中,serialize_data的作用就像上述例子中的食材打包。如果没有序列化,每个工作进程(服务员)都需要一份完整的数据集副本(完整的菜单),这会导致大量的内存占用和数据重复。启用序列化后,数据集的每个样本都被打包成一个二进制格式的“包裹”(data_bytes),并附有一个地址标签(data_address),工作进程只需要根据需要加载和处理这些“包裹”,而不是整个数据集,这样可以显著减少内存的使用,提高数据处理的效率。

精彩文章

评论可见,请评论后查看内容,谢谢!!!
 您阅读本篇文章共花了: