Skip to content

Commit

Permalink
fix mask issue and deal with come comment
Browse files Browse the repository at this point in the history
  • Loading branch information
CfromBU committed Jan 7, 2025
1 parent 46145ff commit 5089e3a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
11 changes: 8 additions & 3 deletions python/dgl/distributed/dist_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def collate(self, items):
raise NotImplementedError

@staticmethod
def add_edge_attribute_to_graph(g, data_name, gb_padding=0):
def add_edge_attribute_to_graph(g, data_name, gb_padding):
"""Add data into the graph as an edge attribute.
For some cases such as prob/mask-based sampling on GraphBolt partitions,
Expand Down Expand Up @@ -348,6 +348,8 @@ class NodeCollator(Collator):
The neighborhood sampler.
gb_padding : int, optional
The padding value for GraphBolt partitions' new edge_attributes.
e.g. some edges of specific types have no mask, the mask will be set as gb_padding.
the edge will not be sampled if the mask is 0.
Examples
--------
Expand Down Expand Up @@ -512,6 +514,10 @@ class EdgeCollator(Collator):
A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`.
gb_padding : int, optional
The padding value for GraphBolt partitions' new edge_attributes if the attributes in DistGraph are None.
e.g. prob/mask-based sampling.
Only when the mask of one edge is set as 1, the edge will be sampled.
Examples
--------
Expand Down Expand Up @@ -616,7 +622,7 @@ def __init__(
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
gb_padding=0,
gb_padding=1,
):
self.g = g
if not isinstance(eids, Mapping):
Expand Down Expand Up @@ -869,7 +875,6 @@ def __init__(self, g, eids, graph_sampler, device=None, **kwargs):
else:
dataloader_kwargs[k] = v

collator_kwargs["gb_padding"] = 1
if device is None:
# for the distributed case default to the CPU
device = "cpu"
Expand Down
9 changes: 7 additions & 2 deletions python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _copy_data_from_shared_mem(name, shape):
class AddEdgeAttributeFromKVRequest(rpc.Request):
"""Add edge attribute from kvstore to local GraphBolt partition."""

def __init__(self, name, kv_names, padding=0):
def __init__(self, name, kv_names, padding):
self._name = name
self._kv_names = kv_names
self._padding = padding
Expand All @@ -170,6 +170,11 @@ def process_request(self, server_state):
gpb = server_state.partition_book
# Initialize the edge attribute.
num_edges = g.total_num_edges

# Padding is used here to fill missing edge attributes (e.g., 'prob' or 'mask') for certain edge types.
# In DGLGraph, some edges may not have attributes or their values could be None.
# GraphBolt, however, samples edges based on specific attributes (like 'mask' == 1), so we pad the missing attributes with default values (e.g., 1 for 'mask')
# to ensure that all edges can be sampled consistently, regardless of whether their attributes are available in the DGLGraph.
attr_data = torch.full((num_edges,), self._padding, dtype=data_type)
# Map data from kvstore to the local partition for inner edges only.
num_inner_edges = gpb.metadata()[gpb.partid]["num_edges"]
Expand Down Expand Up @@ -1621,7 +1626,7 @@ def _get_edata_names(self, etype=None):
edata_names.append(name)
return edata_names

def add_edge_attribute(self, name, padding=0):
def add_edge_attribute(self, name, padding):
"""Add an edge attribute into GraphBolt partition from edge data.
Parameters
Expand Down
6 changes: 5 additions & 1 deletion tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import unittest
from pathlib import Path

import backend as F
import dgl

import dgl.backend as F
import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -1919,6 +1920,9 @@ def check_hetero_dist_edge_dataloader_gb(

block = next(iter(loader))[2][0]
assert block.num_src_nodes("n1") > 0
assert block.num_edges("r12") > 0
assert block.num_edges("r13") > 0
assert block.num_edges("r23") > 0


def test_hetero_dist_edge_dataloader_gb(
Expand Down

0 comments on commit 5089e3a

Please sign in to comment.