@@ -32,20 +32,23 @@ class UniformNegativeSampler(NegativeSampler):
3232 Examples
3333 --------
3434 >>> from dgl import graphbolt as gb
35- >>> indptr = torch.LongTensor([0, 2, 4, 5 ])
36- >>> indices = torch.LongTensor([1, 2, 0, 2 , 0])
35+ >>> indptr = torch.LongTensor([0, 1, 2, 3, 4 ])
36+ >>> indices = torch.LongTensor([1, 2, 3 , 0])
3737 >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
38- >>> node_pairs = ( torch.tensor([0, 1]), torch.tensor( [1, 2]) )
38+ >>> node_pairs = torch.tensor([[ 0, 1], [1, 2], [2, 3], [3, 0]] )
3939 >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
4040 >>> item_sampler = gb.ItemSampler(
41- ... item_set, batch_size=1 ,)
41+ ... item_set, batch_size=4 ,)
4242 >>> neg_sampler = gb.UniformNegativeSampler(
4343 ... item_sampler, graph, 2)
4444 >>> for minibatch in neg_sampler:
4545 ... print(minibatch.negative_srcs)
4646 ... print(minibatch.negative_dsts)
47- (tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
48- (tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))
47+ None
48+ tensor([[2, 1],
49+ [2, 1],
50+ [3, 2],
51+ [1, 3]])
4952 """
5053
5154 def __init__ (
0 commit comments