# [Model] CoBFormer #233
**bearsticker2222** wants to merge 7 commits into `BUPT-GAMMA:main` from `bearsticker2222:wqytest`.
## Commits
- `79774f6` first submit
- `cce1e00` first commit cobformer
- `3fdefb6` fix bugs
- `7c45407` (gyzhou2000) Merge remote-tracking branch 'upstream/main' into wqytest
- `b6128b4` (gyzhou2000) add readme, modified main, loss and optimizer part
- `652e731` add readme, modified main, loss and optimizer part
- `5f8a95c` (gyzhou2000) modify code
## Files changed

### Trainer script (invoked as `cobformer_trainer.py` by the README below; new file)

```python
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: output all; 1: filter out INFO; 2: filter out INFO and WARNING;
# 3: filter out INFO, WARNING, and ERROR

import argparse
import tensorlayerx as tlx
import numpy as np
from sklearn.metrics import f1_score
from partition import partition_patch
from gammagl.datasets import Planetoid
from gammagl.models.cobformer import CoBFormer
from tensorlayerx.model import TrainOneStep, WithLoss


def eval_f1(pred, label, num_classes):
    # num_classes is currently unused; kept for call-site compatibility
    pred = tlx.convert_to_numpy(pred)
    label = tlx.convert_to_numpy(label)
    micro = f1_score(label, pred, average='micro')
    macro = f1_score(label, pred, average='macro')
    return micro, macro


class CoLoss(WithLoss):
    def __init__(self, model, loss_fn):
        super(CoLoss, self).__init__(backbone=model, loss_fn=loss_fn)
        self.alpha = model.alpha
        self.tau = model.tau

    def forward(self, data, label):
        pred1, pred2 = self.backbone_network(data['x'], data['patch'], data['edge_index'],
                                             edge_weight=data['edge_weight'],
                                             num_nodes=data['num_nodes'])
        # supervised cross-entropy on the labeled nodes, one term per branch
        l1 = tlx.losses.softmax_cross_entropy_with_logits(pred1[data['train_mask']], label[data['train_mask']])
        l2 = tlx.losses.softmax_cross_entropy_with_logits(pred2[data['train_mask']], label[data['train_mask']])

        # temperature-scaled logits for mutual distillation
        pred1_scaled = pred1 * self.tau
        pred2_scaled = pred2 * self.tau

        # on the unlabeled nodes, each branch fits the other's softened prediction
        l3 = tlx.losses.softmax_cross_entropy_with_logits(pred1_scaled[~data['train_mask']],
                                                          tlx.nn.Softmax()(pred2_scaled)[~data['train_mask']])
        l4 = tlx.losses.softmax_cross_entropy_with_logits(pred2_scaled[~data['train_mask']],
                                                          tlx.nn.Softmax()(pred1_scaled)[~data['train_mask']])

        return self.alpha * (l1 + l2) + (1 - self.alpha) * (l3 + l4)


def calculate_acc(logits, y, metrics):
    """
    Args:
        logits: node logits
        y: node labels
        metrics: tensorlayerx.metrics

    Returns:
        rst
    """
    metrics.update(logits, y)
    rst = metrics.result()
    metrics.reset()
    return rst


def main(args):
    # load datasets
    if str.lower(args.dataset) not in ['cora', 'pubmed', 'citeseer']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    dataset = Planetoid(args.dataset)
    graph = dataset[0]

    graph.train_mask = tlx.convert_to_numpy(graph.train_mask)
    graph.val_mask = tlx.convert_to_numpy(graph.val_mask)
    graph.test_mask = tlx.convert_to_numpy(graph.test_mask)
    # partition_patch appends one virtual padding node to the graph, so each
    # 1D mask is padded with a single False entry at the end
    graph.train_mask = np.pad(graph.train_mask, (0, 1), mode='constant')
    graph.val_mask = np.pad(graph.val_mask, (0, 1), mode='constant')
    graph.test_mask = np.pad(graph.test_mask, (0, 1), mode='constant')

    patch = partition_patch(graph, args.n_patch)

    # convert labels to one-hot encoding and cast to float
    label = tlx.nn.OneHot(dataset.num_classes)(graph.y)
    label = tlx.cast(label, dtype=tlx.float32)

    model = CoBFormer(graph.num_nodes, dataset.num_node_features, args.num_hidden, dataset.num_classes,
                      layers=args.num_layers, gcn_layers=args.gcn_layers, n_head=args.n_head,
                      alpha=args.alpha, tau=args.tau, use_patch_attn=args.use_patch_attn)

    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
    train_weights = model.trainable_weights

    loss_func = CoLoss(model, tlx.losses.softmax_cross_entropy_with_logits)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data = {
        "x": graph.x,
        "y": graph.y,
        "edge_index": graph.edge_index,
        "edge_weight": None,
        "train_mask": graph.train_mask,
        "test_mask": graph.test_mask,
        "val_mask": graph.val_mask,
        "num_nodes": graph.num_nodes,
        'train': graph.train_mask,
        'valid': graph.val_mask,
        'test': graph.test_mask,
        'patch': patch
    }

    for epoch in range(args.n_epoch):
        model.set_train()
        loss = train_one_step(data, label)
        model.set_eval()

        pred1, pred2 = model(data['x'], data['patch'], data['edge_index'],
                             edge_weight=data['edge_weight'], num_nodes=data['num_nodes'])

        y = data['y']
        num_classes = int(tlx.reduce_max(y) + 1)

        y1_ = tlx.argmax(pred1, axis=1)
        micro_val1, macro_val1 = eval_f1(y1_[data['valid']], y[data['valid']], num_classes)

        y2_ = tlx.argmax(pred2, axis=1)
        if len(y2_.shape) > 1:
            # flatten with tlx.reshape rather than the torch-only .view()
            y2_ = tlx.reshape(y2_, [-1])
        micro_val2, macro_val2 = eval_f1(y2_[data['valid']], y[data['valid']], num_classes)

        print("Epoch [{:0>3d}]".format(epoch + 1)
              + " train loss: {:.4f}".format(loss.item())
              + " GCN micro_val acc: {:.4f}".format(micro_val1)
              + " GCN macro_val acc: {:.4f}".format(macro_val1)
              + " COB micro_val acc: {:.4f}".format(micro_val2)
              + " COB macro_val acc: {:.4f}".format(macro_val2))


if __name__ == '__main__':
    # parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cora', help='dataset')
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument("--l2_coef", type=float, default=5e-4, help="l2 loss coefficient")
    parser.add_argument('--gcn_wd', type=float, default=5e-4)
    parser.add_argument('--num_hidden', type=int, default=64, help='Number of hidden units')
    parser.add_argument('--num_layers', type=int, default=1, help='Number of layers')
    parser.add_argument('--n_head', type=int, default=4, help='Number of attention heads')
    parser.add_argument('--n_epoch', type=int, default=500, help='Number of training epochs')
    parser.add_argument('--use_patch_attn', action='store_true', help='transformer uses patch attention')
    parser.add_argument('--show_details', type=bool, default=True)
    parser.add_argument('--gcn_layers', type=int, default=2)
    parser.add_argument('--n_patch', type=int, default=112)
    parser.add_argument('--batch_size', type=int, default=100000)
    parser.add_argument('--train_prop', type=float, default=.6)
    parser.add_argument('--valid_prop', type=float, default=.2)
    parser.add_argument('--alpha', type=float, default=.8)
    parser.add_argument('--tau', type=float, default=.3)
    parser.add_argument('--gpu', type=int, default=0)

    args = parser.parse_args()

    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)
```
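The heart of this trainer is the `CoLoss` objective: `alpha` weights the supervised cross-entropy of both branches, while `1 - alpha` weights a mutual-distillation term in which each branch fits the other's tau-softened prediction on the unlabeled nodes. Below is a minimal NumPy sketch of the same arithmetic, with dummy logits and made-up shapes, purely for illustration; the real training path goes through `CoLoss` above.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_entropy(logits, target_probs):
    # mean over nodes of -sum_c target_c * log softmax(logits)_c
    logp = np.log(softmax(logits) + 1e-12)
    return -(target_probs * logp).sum(axis=1).mean()

rng = np.random.default_rng(0)
num_nodes, num_classes = 8, 3
pred1 = rng.normal(size=(num_nodes, num_classes))  # GCN branch logits (dummy)
pred2 = rng.normal(size=(num_nodes, num_classes))  # transformer branch logits (dummy)
train_mask = np.array([True] * 3 + [False] * 5)
label = np.eye(num_classes)[rng.integers(0, num_classes, num_nodes)]  # one-hot

alpha, tau = 0.8, 0.3
l1 = cross_entropy(pred1[train_mask], label[train_mask])
l2 = cross_entropy(pred2[train_mask], label[train_mask])
# mutual distillation: each branch fits the other's tau-softened prediction
l3 = cross_entropy(pred1[~train_mask] * tau, softmax(pred2 * tau)[~train_mask])
l4 = cross_entropy(pred2[~train_mask] * tau, softmax(pred1 * tau)[~train_mask])
loss = alpha * (l1 + l2) + (1 - alpha) * (l3 + l4)
print(f"co-training loss: {loss:.4f}")
```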
### `partition.py` (imported by the trainer; new file)

```python
import tensorlayerx as tlx
import numpy as np
import networkx as nx
import metis


def partition_patch(graph, n_patches, load_path=None):
    if load_path is not None:
        # load a precomputed partition saved as a .npy file
        patch = np.load(load_path)
        patch = tlx.convert_to_tensor(patch, dtype=tlx.int64)
    else:
        if n_patches == 1:
            # a single patch containing every node plus the virtual padding node
            patch = np.arange(graph.num_nodes + 1)
            patch = tlx.convert_to_tensor(patch, dtype=tlx.int64)
            patch = tlx.expand_dims(patch, axis=0)
        else:
            patch = metis_partition(g=graph, n_patches=n_patches)
        print('metis done!')
    print('patch done!')

    # Append a virtual padding node to the graph. The torch version does this
    # with F.pad(graph.x, [0, 0, 0, 1]) and F.pad(graph.y, [0, 1]); np.pad is
    # used here for the same effect.
    graph.num_nodes += 1

    # pad x with one zero row
    padded_x = np.pad(tlx.convert_to_numpy(graph.x),
                      pad_width=((0, 1), (0, 0)),
                      mode='constant',
                      constant_values=0)
    graph.x = tlx.convert_to_tensor(padded_x)

    # pad y with one zero entry
    padded_y = np.pad(tlx.convert_to_numpy(graph.y),
                      pad_width=(0, 1),
                      mode='constant',
                      constant_values=0)
    graph.y = tlx.convert_to_tensor(padded_y)

    return patch


def metis_partition(g, n_patches=50):
    if g.num_nodes < n_patches:
        # fewer nodes than requested patches: assign patches randomly
        membership = np.random.permutation(n_patches)
        membership = tlx.convert_to_tensor(membership, dtype=tlx.int64)
    else:
        # otherwise partition the graph with METIS
        adjlist = g.edge_index.T                  # edge list, shape [num_edges, 2]
        G = nx.Graph()                            # build an undirected NetworkX graph
        G.add_nodes_from(np.arange(g.num_nodes))  # add nodes
        G.add_edges_from(adjlist.tolist())        # add edges

        cuts, membership = metis.part_graph(G, n_patches, recursive=True)

        # METIS must return one membership entry per node
        assert len(membership) >= g.num_nodes
        membership = tlx.convert_to_tensor(membership[:g.num_nodes], dtype=tlx.int64)

    patch = []           # node indices of each patch
    max_patch_size = -1  # size of the largest patch

    for i in range(n_patches):
        # nodes assigned to patch i (np.where in place of torch.where)
        patch.append(np.where(tlx.convert_to_numpy(membership == i))[0].tolist())
        max_patch_size = max(max_patch_size, len(patch[-1]))

    # pad every patch to the same size with the virtual node index g.num_nodes
    for i in range(len(patch)):
        l = len(patch[i])
        if l < max_patch_size:
            patch[i] += [g.num_nodes] * (max_patch_size - l)

    # final partition result, shape [n_patches, max_patch_size]
    patch = tlx.convert_to_tensor(patch, dtype=tlx.int64)
    return patch
```
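To make the padding convention concrete, here is a tiny METIS-free sketch with a toy membership vector (illustrative values only): each patch is a list of node indices, and shorter patches are padded with the virtual node id `num_nodes` so the result packs into a dense `[n_patches, max_patch_size]` tensor.

```python
import numpy as np

num_nodes, n_patches = 7, 3
membership = np.array([0, 1, 0, 2, 1, 1, 0])  # toy stand-in for METIS output

patch = [np.where(membership == i)[0].tolist() for i in range(n_patches)]
max_patch_size = max(len(p) for p in patch)
for p in patch:
    p += [num_nodes] * (max_patch_size - len(p))  # pad with the virtual node id

print(np.array(patch))
# [[0 2 6]
#  [1 4 5]
#  [3 7 7]]  <- patch 2 has one real node; 7 is the padding id
```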
### README (new file)

Note: the run commands use the flags the trainer's argument parser actually defines (`--lr`, `--l2_coef`, `--gpu`); the originally submitted commands referenced `--learning_rate`, `--weight_decay`, `--gcn_type`, and `--gpu_id`, which the parser does not accept.

````markdown
# Less is More: on the Over-Globalizing Problem in Graph Transformers (CoBFormer)

- Paper link: [Less is More: on the Over-Globalizing Problem in Graph Transformers](http://arxiv.org/abs/2405.01102)

## Dataset Statistics

Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid).

## Results

```bash
# Cora
python cobformer_trainer.py --dataset=cora --lr=0.01 --gcn_wd=1e-3 --l2_coef=5e-5 --gcn_layers=2 --n_patch=112 --use_patch_attn --alpha=0.7 --tau=0.3 --gpu=0

# CiteSeer
python cobformer_trainer.py --dataset=citeseer --lr=5e-3 --gcn_wd=1e-2 --l2_coef=5e-5 --gcn_layers=2 --n_patch=144 --use_patch_attn --alpha=0.8 --tau=0.7 --gpu=0

# PubMed
python cobformer_trainer.py --dataset=pubmed --lr=5e-3 --gcn_wd=1e-3 --l2_coef=1e-3 --gcn_layers=2 --n_patch=224 --use_patch_attn --alpha=0.7 --tau=0.3 --gpu=0
```

| Dataset  | Paper | Our (tf)     |
| -------- | ----- | ------------ |
| Cora     | 85.28 | 83.16 ± 0.59 |
| CiteSeer | 74.52 | 71.20 ± 0.80 |
| PubMed   | 81.42 | 81.84 ± 0.45 |
````
### BGA module (filename assumed to be `bga.py`; imports `BGALayer` from a sibling `bga_layer` module; new file)

```python
import tensorlayerx as tlx
from .bga_layer import BGALayer


class BGA(tlx.nn.Module):
    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int, out_channels: int,
                 layers: int, n_head: int, use_patch_attn=True, dropout1=0.5, dropout2=0.1, need_attn=False):
        super(BGA, self).__init__()
        self.layers = layers
        self.n_head = n_head
        self.num_nodes = num_nodes
        self.dropout = tlx.nn.Dropout(p=dropout1)
        # linear encoder that maps raw node attributes to the hidden dimension
        self.attribute_encoder = tlx.nn.Linear(in_features=in_channels, out_features=hidden_channels)
        self.BGALayers = tlx.nn.ModuleList()
        for _ in range(0, layers):
            self.BGALayers.append(
                BGALayer(n_head, hidden_channels, use_patch_attn, dropout=dropout2))
        self.classifier = tlx.nn.Linear(in_features=hidden_channels, out_features=out_channels)
        self.attn = []

    def forward(self, x, patch, need_attn=False):
        # entries equal to num_nodes - 1 are the virtual padding node; mask them
        # out so padded slots neither attend nor get attended to
        patch_mask = tlx.cast(patch != self.num_nodes - 1, dtype=tlx.float32)
        patch_mask = tlx.expand_dims(patch_mask, axis=-1)
        attn_mask = tlx.cast(tlx.matmul(patch_mask, tlx.transpose(patch_mask, perm=[0, 2, 1])), dtype=tlx.int32)

        x = tlx.relu(self.attribute_encoder(x))

        for i in range(0, self.layers):
            x = self.BGALayers[i](x, patch, attn_mask, need_attn)
            if need_attn:
                self.attn.append(self.BGALayers[i].attn)
        x = self.dropout(x)
        x = self.classifier(x)
        return x
```
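The mask construction in `forward` is easy to sanity-check with NumPy (toy patch tensor, assumed shapes): the outer product of the per-slot validity vector zeroes out every attention pair that involves a padded slot.

```python
import numpy as np

num_nodes = 8  # includes the virtual padding node, whose id is num_nodes - 1 = 7
patch = np.array([[0, 2, 6],
                  [1, 4, 5],
                  [3, 7, 7]])  # [n_patches, patch_size]

patch_mask = (patch != num_nodes - 1).astype(np.float32)[..., None]        # [P, S, 1]
attn_mask = (patch_mask @ patch_mask.transpose(0, 2, 1)).astype(np.int32)  # [P, S, S]

print(attn_mask[2])
# [[1 0 0]
#  [0 0 0]
#  [0 0 0]]  -> only the single real node in patch 2 attends (to itself)
```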
## Review comment

> Following the README files of the other models, please add a dataset description, the run commands, and the results.