PyG Batching

内容来自这里.

高级 Mini-Batching(Mini-Batching)

创建 mini-batching 对于让深度学习模型的训练扩展到海量数据至关重要。mini-batch 不会一个接一个地处理样本,而是将一组样本分组到一个统一的表示中,从而可以高效地并行处理。在图像或语言领域,这个过程通常是通过将每个样本重新缩放或 padding 到一组等大小的形状来实现的,然后将样本分组到一个额外的维度中。这个维度的长度等于分组在一个 mini-batch 中的样本数量,通常称为 batch_size。

由于图是一种最通用的数据结构,可以包含任意数量的节点(nodes)或边(edges),因此上述两种方法要么不可行,要么可能导致大量不必要的内存消耗。在 PyG 中,我们采用另一种方法来实现对大量样本的并行化。在这里,adjacency matrices 以对角线方式堆叠(创建一个包含多个孤立子图的巨大图),并且节点和目标特征(features)简单地沿节点维度进行拼接,即:

$$ A = \begin{bmatrix} A_1 & & \\ & \ddots & \\ & & A_n \end{bmatrix}, \quad X = \begin{bmatrix} X_1 \\ \vdots \\ X_n \end{bmatrix}, \quad Y = \begin{bmatrix} Y_1 \\ \vdots \\ Y_n \end{bmatrix}. $$

与其他 batching 过程相比,此过程具有一些关键优势:

  • 依赖于 message passing scheme 的 GNN operators 不需要修改,因为属于不同图的两个节点之间仍然不能交换消息。
  • 没有计算或内存开销。例如,此 batching 过程完全不需要对节点或边的特征进行任何 padding。请注意,adjacency matrices 没有额外的内存开销,因为它们以稀疏(sparse)方式保存,只包含非零项,即边。

PyG 借助 torch_geometric.loader.DataLoader 类自动将多个图 batch 成一个巨大的图。在内部,DataLoader 只是一个常规的 PyTorch torch.utils.data.DataLoader,它重写了其 collate() 功能,即定义如何将样本列表分组在一起。因此,所有可以传递给 PyTorch DataLoader 的参数也可以传递给 PyG DataLoader,例如 worker 数量 num_workers

在其最一般的形式中,PyG DataLoader 会自动将 edge_index tensor 增加到当前处理图之前已聚合(collated)的所有图的累积节点数,并将 edge_index tensors(形状为 [2, num_edges])在第二个维度上进行拼接。face tensors,即 mesh 中的 face 索引,也是如此。所有其他 tensors 将仅在第一个维度上进行拼接,而不会进一步增加其值。

然而,存在一些特殊用例(如下所述),用户需要根据自己的需求主动修改此行为。PyG 允许通过重写 torch_geometric.data.Data.__inc__()torch_geometric.data.Data.__cat_dim__() 功能来修改底层的 batching 过程。在没有任何修改的情况下,它们在 Data 类中定义如下:

def __inc__(self, key, value, *args, kwargs):
    if 'index' in key:
        return self.num_nodes
    else:
        return 0

def __cat_dim__(self, key, value, *args, kwargs):
    if 'index' in key:
        return 1
    else:
        return 0

我们可以看到 __inc__() 定义了两个连续图属性之间的增量计数。默认情况下,只要属性名称包含子字符串 index(出于历史原因),PyG 就会将属性增加节点数量 num_nodes,这对于 edge_indexnode_index 等属性非常方便。但是请注意,这可能会导致属性名称包含子字符串 index 但不应增加的属性出现意外行为。为确保正确,最佳实践是始终仔细检查 batching 的输出。此外,__cat_dim__() 定义了相同属性的图 tensors 应在哪个维度上进行拼接。这两个函数都会为存储在 Data 类中的每个属性调用,并将它们的特定 keyvalue 项作为参数传递。

接下来,我们将介绍一些可能绝对需要修改 __inc__()__cat_dim__() 的用例。

图对(Pairs of Graphs)

如果您想在单个 Data object 中存储多个图,例如用于图匹配(graph matching)等应用,则需要确保所有这些图的 batching 行为正确。例如,考虑在 Data 中存储两个图:一个源图 $G_s$ 和一个目标图 $G_t$,例如:

from torch_geometric.data import Data

class PairData(Data):
    pass

data = PairData(x_s=x_s, edge_index_s=edge_index_s,  # Source graph.
                x_t=x_t, edge_index_t=edge_index_t)  # Target graph.

在这种情况下,edge_index_s 应该增加源图 $G_s$ 中的节点数,例如 x_s.size(0),而 edge_index_t 应该增加目标图 $G_t$ 中的节点数,例如 x_t.size(0)

class PairData(Data):
    def __inc__(self, key, value, *args, kwargs):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        return super().__inc__(key, value, *args, kwargs)

我们可以通过设置一个简单的测试脚本来测试我们的 PairData batching 行为:

from torch_geometric.loader import DataLoader
import torch

x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])

x_t = torch.randn(4, 16)  # 4 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])

data = PairData(x_s=x_s, edge_index_s=edge_index_s,
                x_t=x_t, edge_index_t=edge_index_t)

data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
# >>> PairDataBatch(x_s=[10, 16], edge_index_s=[2, 8], x_t=[8, 16], edge_index_t=[2, 6])
print(batch.edge_index_s)
# >>> tensor([[0, 0, 0, 0, 5, 5, 5, 5],
#             [1, 2, 3, 4, 6, 7, 8, 9]])
print(batch.edge_index_t)
# >>> tensor([[0, 0, 0, 4, 4, 4],
#             [1, 2, 3, 5, 6, 7]])

到目前为止一切顺利!即使 $G_s$ 和 $G_t$ 使用不同数量的节点,edge_index_sedge_index_t 也能正确地 batch 到一起。然而,batch 属性(将每个节点映射到其各自的图)丢失了,因为 PyG 无法识别 PairData object 中的实际图。这就是 DataLoaderfollow_batch 参数发挥作用的地方。在这里,我们可以指定要为哪些属性维护 batch 信息:

loader = DataLoader(data_list, batch_size=2, follow_batch=['x_s', 'x_t'])
batch = next(iter(loader))

print(batch)
# >>> PairDataBatch(x_s=[10, 16], edge_index_s=[2, 8], x_s_batch=[10], x_t=[8, 16], edge_index_t=[2, 6], x_t_batch=[8])
print(batch.x_s_batch)
# >>> tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
print(batch.x_t_batch)
# >>> tensor([0, 0, 0, 0, 1, 1, 1, 1])

正如所见,follow_batch=['x_s', 'x_t'] 现在成功地为节点特征 x_sx_t 分别创建了分配向量 x_s_batchx_t_batch。现在可以使用这些信息在单个 Batch object 中对多个图执行 reduce 操作,例如 global pooling。

二分图(Bipartite Graphs)

bipartite graph 的 adjacency matrix 定义了两种不同节点类型之间的关系。通常,每种节点类型的节点数量不必匹配,从而导致一个形状为 $A \in \{0, 1\}^{N \times M}$ 且可能 $N \ne M$ 的非方阵 adjacency matrix。在 bipartite graphs 的 mini-batching 过程中,edge_index 中边的源节点应以不同于 edge_index 中边的目标节点的方式增加。为了实现这一点,考虑一个介于两种节点类型之间的 bipartite graph,分别具有相应的节点特征 x_sx_t

from torch_geometric.data import Data

class BipartiteData(Data):
    pass

data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index)

对于 bipartite graphs 中正确的 mini-batching 过程,我们需要告诉 PyG 它应该独立地增加 edge_index 中边的源节点和目标节点:

class BipartiteData(Data):
    def __inc__(self, key, value, *args, kwargs):
        if key == 'edge_index':
            return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
        return super().__inc__(key, value, *args, kwargs)

在这里,edge_index[0](边的源节点)增加了 x_s.size(0),而 edge_index[1](边的目标节点)增加了 x_t.size(0)。我们可以再次通过运行一个简单的测试脚本来测试我们的实现:

from torch_geometric.loader import DataLoader
import torch

x_s = torch.randn(2, 16)  # 2 nodes.
x_t = torch.randn(3, 16)  # 3 nodes.

edge_index = torch.tensor([
    [0, 0, 1, 1],
    [0, 1, 1, 2],
])

data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
# >>> BipartiteDataBatch(x_s=[4, 16], x_t=[6, 16], edge_index=[2, 8])
print(batch.edge_index)
# >>> tensor([[0, 0, 1, 1, 2, 2, 3, 3],
#             [0, 1, 1, 2, 3, 4, 4, 5]])

再次,这正是我们想要的行为!

沿新维度进行 Batching(Batching Along New Dimensions)

有时,Data object 的属性应该通过获得一个新的 batch dimension 来进行 batching(如经典 mini-batching 中那样),例如对于图级别的属性或目标。具体来说,形状为 [num_features] 的属性列表应作为 [num_examples, num_features] 返回,而不是 [num_examples * num_features]。PyG 通过在 __cat_dim__() 中返回一个 None 的拼接维度来实现这一点:

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

class MyData(Data):
    def __cat_dim__(self, key, value, *args, kwargs):
        if key == 'foo':
            return None
        return super().__cat_dim__(key, value, *args, kwargs)

edge_index = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1],
])

foo = torch.randn(16)
data = MyData(num_nodes=3, edge_index=edge_index, foo=foo)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
# >>> MyDataBatch(num_nodes=6, edge_index=[2, 8], foo=[2, 16])

如预期,batch.foo 现在由两个维度描述:batch dimension 和 feature dimension。