I'm using PyTorch Geometric and have a custom Torus class (a subclass of Data) that initializes the edge_index, edge_attr, and pos attributes only when no_fixed is False. For example:
import torch
from torch_geometric.data import Data


class Torus(Data):
    def __init__(self, N: int = 64, L: float = 64.0, M: int = 64,
                 x: torch.Tensor = None, y: torch.Tensor = None,
                 eight_neighbors: bool = False, device=None, no_fixed: bool = False):
        if no_fixed:
            # Skip the fixed grid structure entirely.
            super().__init__(x=x, y=y)
        else:
            # Build the periodic grid and its connectivity.
            pos = grid_pos(N, L, device)
            edge_index, edge_attr = torus_edges(N, L, pos, eight_neighbors, device)
            super().__init__(x=x, y=y, edge_index=edge_index, edge_attr=edge_attr, pos=pos)
        # Store scalar metadata as 1-element tensors so PyG can batch them.
        self.N = torch.tensor(N, dtype=torch.long).unsqueeze(0)
        self.L = torch.tensor(L, dtype=torch.float).unsqueeze(0)
        self.M = torch.tensor(M, dtype=torch.long).unsqueeze(0)
        self.eight_neighbors = torch.tensor(eight_neighbors, dtype=torch.int).unsqueeze(0)
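Here, grid_pos and torus_edges are my own helpers; their exact implementations shouldn't matter for the question, but minimal hypothetical stand-ins (just so the class above is runnable) could look like this:

import torch

# Hypothetical stand-ins for my helpers; the real implementations build
# an N x N periodic grid with spacing L / N.
def grid_pos(N, L, device=None):
    coords = torch.arange(N, device=device, dtype=torch.float) * (L / N)
    yy, xx = torch.meshgrid(coords, coords, indexing='ij')
    return torch.stack([xx.flatten(), yy.flatten()], dim=-1)  # [N^2, 2]

def torus_edges(N, L, pos, eight_neighbors=False, device=None):
    # 4-neighbor connectivity with periodic (wrap-around) boundaries;
    # the eight_neighbors case is omitted for brevity.
    idx = torch.arange(N * N, device=device).view(N, N)
    right = torch.stack([idx.flatten(), idx.roll(-1, dims=1).flatten()])
    down = torch.stack([idx.flatten(), idx.roll(-1, dims=0).flatten()])
    edge_index = torch.cat([right, down], dim=1)         # [2, 2 * N^2]
    # Relative offsets as edge features (wrap-around not corrected here).
    edge_attr = pos[edge_index[1]] - pos[edge_index[0]]
    return edge_index, edge_attr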
When I create individual Torus objects with no_fixed=True, they do not have the edge_index, edge_attr, or pos attributes. However, if I batch them using:
batch = Batch.from_data_list(torus_list)
the resulting batch unexpectedly contains edge_index, edge_attr, and pos (with shapes like [128^2, ...]).
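For concreteness, here is a minimal check along the lines of what I'm doing (the feature shapes are placeholders; any tensors would do):

import torch
from torch_geometric.data import Batch

# Placeholder features; the exact shapes are not important here.
torus_list = [
    Torus(x=torch.randn(64 * 64, 3), y=torch.randn(64 * 64, 3), no_fixed=True)
    for _ in range(4)
]

# None of the individual objects carry the grid structure:
print('edge_index' in torus_list[0])  # False
print('pos' in torus_list[0])         # False

batch = Batch.from_data_list(torus_list)

# ...yet in my runs the batch does:
print('edge_index' in batch)  # True (unexpected)
print('pos' in batch)         # True (unexpected)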
Why does the batching process add these attributes to the batch even though none of the individual objects have them? Is this expected behavior due to how PyG’s Batch class works?
Any insights would be appreciated!