SGL-PT: A Strong Graph Learner with Graph Prompt Tuning
Yun Zhu, Jianhao Guo, Siliang Tang

TL;DR
This paper introduces SGL-PT, a novel graph learning framework that combines strong pre-training with prompt tuning to improve downstream task performance, especially in biological datasets, by bridging the gap between pretext and target tasks.
Contribution
It proposes a universal pre-training task called SGL and a verbalizer-free prompt tuning method to unify pre-training and downstream tasks in graph learning.
Findings
Outperforms baseline methods in unsupervised settings
Enhances model performance on biological graph datasets
Bridges pre-training and fine-tuning gap effectively
| Datasets | type | # graphs | Avg # nodes | Avg # edges |
| PROTEINS | Biological | 1113 | 39.06 | 72.82 |
| DD | Biological | 1178 | 284.32 | 715.66 |
| MUTAG | Biological | 188 | 17.93 | 19.79 |
| NCI1 | Biological | 4110 | 29.87 | 32.30 |
| NCI-H23 | Biological | 40353 | 26.07 | 28.10 |
| P388 | Biological | 41472 | 22.11 | 23.56 |
| MOLT-4 | Biological | 39765 | 26.10 | 28.14 |
| IMDB-B | Social | 1000 | 19.77 | 96.53 |
| IMDB-M | Social | 1500 | 13.00 | 65.94 |
| COLLAB | Social | 5000 | 74.49 | 2457.78 |
| REDDIT-B | Social | 2000 | 508.52 | 594.87 |
| REDDIT-M12K | Social | 11929 | 391.41 | 456.89 |
| Methods | | PROTEINS | DD | NCI1 | MUTAG | IMDB-B | IMDB-M | COLLAB | REDDIT-B |
| Supervised | GCN | 74.9±3.3 | 75.9±2.5 | 80.2±2.0 | 85.6±5.8 | 70.4±3.4 | 51.9±3.8 | 79.0±1.8 | - |
| | GIN | 76.2±2.8 | 75.3±2.9 | 82.7±1.7 | 89.4±5.6 | 75.1±5.1 | 52.3±2.8 | 80.2±1.9 | 92.4±2.5 |
| | DiffPool | 75.1±3.5 | - | - | 85.0±10.3 | 72.6±3.9 | - | 78.9±2.3 | 92.1±2.6 |
| Graph Kernels | GL | 71.67±0.55 | 72.54±3.83 | - | 81.66±2.11 | 65.87±0.98 | 43.89±0.38 | 56.30±0.60 | 77.34±0.18 |
| | WL | 72.92±0.56 | 79.78±0.36 | 80.01±0.50 | 80.72±3.00 | 72.30±3.44 | 46.95±0.46 | 69.30±3.44 | 68.82±0.41 |
| | DGK | 73.30±0.82 | 73.50±1.01 | 80.31±0.46 | 87.44±2.72 | 66.96±0.56 | 44.55±0.52 | 64.66±0.50 | 78.04±0.39 |
| Self-supervised | sub2vec | 53.03±5.55 | 54.33±2.44 | 52.89±1.61 | 61.05±15.79 | 55.26±1.54 | 36.67±0.83 | 55.26±1.54 | 71.48±0.41 |
| | graph2vec | 73.30±2.05 | 70.32±2.32 | 73.22±1.81 | 83.15±9.25 | 71.10±0.54 | 50.44±0.87 | 71.10±0.54 | 75.78±1.03 |
| | EdgePred | 73.12±1.54 | 72.34±1.04 | 74.41±1.50 | 84.49±1.56 | 68.48±1.11 | 44.83±0.65 | 64.80±1.16 | 84.48±0.68 |
| | Infograph | 74.44±0.31 | 72.85±1.78 | 76.20±1.06 | 89.01±1.13 | 73.03±0.87 | 49.69±0.53 | 70.65±1.13 | 82.50±1.42 |
| | GraphCL | 74.39±0.45 | 78.62±0.40 | 77.87±0.41 | 86.80±1.34 | 71.14±0.44 | 48.58±0.67 | 71.36±1.15 | 89.53±0.84 |
| | JOAO | 74.55±0.41 | 77.32±0.54 | 78.07±0.47 | 87.35±1.02 | 70.21±3.08 | 49.20±0.77 | 69.50±0.36 | 85.29±1.35 |
| | SimGRACE | 75.35±0.09 | 77.44±1.11 | 79.12±0.44 | 89.01±1.31 | 71.30±0.77 | - | 71.72±0.82 | 89.51±0.89 |
| | MVGRL | - | - | - | 89.70±1.10 | 74.20±0.70 | 51.20±0.50 | - | 84.50±0.60 |
| | InfoGCL | - | - | 80.20±0.60 | 91.20±1.30 | 75.10±0.90 | 51.40±0.80 | 80.00±1.30 | - |
| | GraphMAE | 75.30±0.39 | 79.42±0.42 | 80.40±0.30 | 88.19±1.26 | 75.52±0.66 | 51.63±0.52 | 80.32±0.46 | 88.01±0.19 |
| | SGL | 76.55±0.19 | 80.54±0.65 | 80.91±0.42 | 88.83±1.44 | 75.88±0.47 | 52.84±0.26 | 80.80±0.23 | 89.97±0.48 |
| | NCI-H23 | MOLT-4 | P388 | RDT-M12K |
| GIN | 78.41±0.36 | 72.40±1.07 | 82.63±0.90 | 35.01±1.48 |
| EdgePred | 72.49±0.72 | 70.53±1.98 | 72.28±0.97 | 28.65±1.48 |
| Infograph | 78.69±0.69 | 71.29±0.53 | 78.00±1.38 | 33.14±0.74 |
| GraphCL | 76.37±1.08 | 73.26±1.14 | 79.92±1.21 | 33.79±2.47 |
| JOAO | 76.69±1.13 | 72.49±0.74 | 79.86±1.96 | 35.76±1.42 |
| MVGRL | 78.08±0.46 | 74.63±0.32 | 80.20±0.89 | 32.21±1.35 |
| GraphMAE | 77.09±0.33 | 73.91±0.51 | 80.55±0.48 | 33.77±1.35 |
| SGL | 79.98±0.53 | 75.10±0.52 | 81.75±0.47 | 36.24±1.83 |
| 10%L.R. | PROTEINS | DD | NCI1 | MUTAG | NCI-H23 | IMDB-B | IMDB-M | Imp | |
| FT | No pre-train. | 67.98±0.41 | 68.32±0.39 | 64.00±0.22 | 64.19±0.53 | 55.14±1.29 | 69.10±0.28 | 43.25±0.34 | 0.00 |
| EdgePred | 68.12±0.93 | 66.77±0.68 | 63.31±0.72 | 67.72±0.33 | 63.44±1.32 | 67.24±1.63 | 41.57±0.38 | 0.88 | |
| GraphCL | 68.58±0.84 | 68.60±0.49 | 68.08±0.48 | 64.76±0.65 | 65.77±0.52 | 69.95±1.83 | 43.39±0.62 | 2.45 | |
| GraphMAE | 68.79±0.77 | 68.70±0.31 | 65.19±0.48 | 71.16±0.83 | 64.25±1.44 | 69.46±0.33 | 44.09±0.49 | 2.81 | |
| 69.41±1.05 | 68.81±0.40 | 66.51±0.28 | 72.51±1.83 | 64.18±1.29 | 69.62±0.24 | 44.69±0.56 | 3.39 | ||
| PT | GPPT | 68.26±0.87 | 66.53±0.85 | 62.85±1.20 | 71.13±1.81 | 53.44±0.81 | 66.04±0.82 | 38.64±1.40 | 0.73 |
| 65.28±1.99 | 66.04±0.72 | 66.24±0.51 | 71.29±2.07 | 51.86±0.70 | 69.52±0.55 | 43.56±0.56 | 0.26 | ||
| GraphPrompt | 71.00±0.46 | 65.24±1.26 | 60.55±2.17 | 71.78±0.75 | 55.39±0.85 | 62.60±2.32 | 40.33±0.60 | 0.72 | |
| 71.26±0.43 | 67.07±0.50 | 64.47±0.27 | 70.42±0.54 | 60.61±2.51 | 69.88±0.23 | 43.44±0.31 | 2.16 | ||
| 67.19±0.75 | 67.97±0.50 | 67.37±0.55 | 71.85±1.04 | 63.24±1.51 | 69.60±0.43 | 43.44±0.43 | 2.67 | ||
| 66.88±0.90 | 67.76±0.80 | 64.59±0.40 | 69.16±1.52 | 68.33±0.60 | 70.18±0.54 | 44.23±0.53 | 2.73 | ||
| 72.94±0.24 | 75.37±0.16 | 68.80±0.18 | 80.07±1.37 | 69.71±0.15 | 70.40±0.77 | 45.28±0.41 | 7.31 |
| Few-shot | PROTEINS | DD | NCI1 | MUTAG | NCI-H23 | ||||||
| 1-shot | 3-shot | 1-shot | 3-shot | 1-shot | 3-shot | 1-shot | 3-shot | 1-shot | 3-shot | ||
| FT | EdgePred | 57.25±2.80 | 58.36±2.06 | 50.88±2.85 | 52.02±1.21 | 50.75±0.30 | 50.94±0.56 | 57.01±4.83 | 58.71±4.04 | 51.94±1.85 | 52.45±2.35 |
| GraphCL | 59.40±1.87 | 59.91±2.88 | 52.30±2.51 | 53.39±2.87 | 51.51±0.28 | 51.72±0.55 | 57.92±3.43 | 62.39±1.53 | 52.28±3.13 | 54.85±2.14 | |
| GraphMAE | 58.46±3.75 | 59.29±1.62 | 51.31±1.63 | 54.14±1.15 | 51.45±1.45 | 52.67±1.86 | 57.67±4.11 | 60.92±3.54 | 50.40±1.58 | 53.72±1.61 | |
| 58.76±1.02 | 59.73±2.65 | 52.87±4.15 | 54.51±2.91 | 51.88±1.65 | 53.22±2.11 | 58.68±2.28 | 62.47±2.51 | 52.06±2.69 | 55.67±1.08 | ||
| PT | GPPT | 58.10±1.11 | 58.63±0.81 | 50.45±1.49 | 51.50±2.04 | 51.50±0.81 | 51.25±0.47 | 63.23±4.35 | 66.06±2.55 | 51.55±1.19 | 51.84±1.14 |
| 58.64±2.43 | 60.02±1.66 | 54.47±3.49 | 56.92±2.20 | 51.29±2.16 | 52.10±2.37 | 63.13±3.53 | 65.86±1.65 | 49.88±2.20 | 50.03±1.96 | ||
| GraphPrompt | 58.94±2.37 | 62.01±1.47 | 53.62±1.09 | 53.38±1.37 | 51.49±0.38 | 51.64±0.58 | 67.62±1.93 | 70.03±2.18 | 55.03±0.75 | 56.97±1.25 | |
| 59.95±2.71 | 62.93±1.37 | 55.03±2.40 | 58.21±0.95 | 52.65±1.06 | 52.50±1.21 | 68.22±2.86 | 71.40±3.12 | 52.90±2.47 | 52.57±1.89 | ||
| 57.76±2.07 | 59.32±2.43 | 53.58±2.99 | 54.24±2.15 | 51.68±0.40 | 51.91±0.88 | 60.15±4.98 | 61.93±3.03 | 51.19±2.93 | 54.20±3.77 | ||
| 57.99±1.47 | 60.99±0.96 | 52.75±3.27 | 54.32±1.57 | 51.65±1.04 | 52.03±0.35 | 61.37±2.78 | 62.42±1.70 | 50.48±1.27 | 52.66±1.94 | ||
| 61.02±2.63 | 64.47±2.10 | 57.15±1.92 | 61.12±2.42 | 53.08±2.06 | 55.24±1.28 | 72.88±5.20 | 78.60±1.81 | 55.61±1.74 | 58.76±2.09 | ||
| | PROTEINS | DD | NCI1 | MUTAG | IMDB-B | IMDB-M | COLLAB | REDDIT-B |
| Full | 76.55±0.19 | 80.54±0.65 | 80.91±0.42 | 88.83±1.44 | 75.88±0.47 | 52.84±0.26 | 80.80±0.23 | 89.97±0.48 |
| w/o local branch | 74.68±0.50 | 79.67±0.46 | 79.12±0.35 | 85.22±1.25 | 74.42±0.16 | 50.59±0.32 | 78.19±0.12 | 86.67±0.37 |
| w/o global branch | 75.18±0.40 | 79.54±0.54 | 80.18±0.28 | 86.81±2.12 | 75.24±0.55 | 51.63±0.52 | 80.33±0.38 | 87.83±0.11 |
| w/o dynamic queue | 74.85±0.43 | 78.50±0.46 | 80.37±0.15 | 85.23±1.13 | 74.89±0.38 | 51.06±0.71 | 80.51±0.29 | 86.70±0.50 |
| PROTEINS | DD | NCI1 | MUTAG | IMDB-B | IMDB-M | COLLAB | REDDIT-B | ZINC | ||
| Model configuration | hidden_size | 512 | 512 | 512 | 32 | 512 | 512 | 256 | 512 | 300 |
| num_layer | 3 | 2 | 2 | 5 | 2 | 3 | 2 | 2 | 5 | |
| activation | prelu | prelu | prelu | prelu | prelu | prelu | relu | prelu | relu | |
| norm | BN | BN | BN | BN | BN | BN | BN | LN | BN | |
| Local branch | scaling factor | 1 | 1 | 2 | 2 | 1 | 1 | 1 | 1 | 1 |
| masking rate | 0.5 | 0.1 | 0.25 | 0.75 | 20 | 0.5 | 0.75 | 0.75 | 0.25 | |
| replace rate | 0.0 | 0.1 | 0.1 | 0.1 | 0.001 | 0.0 | 0.0 | 0.1 | 0.0 | |
| Global branch | 1024 | 1024 | 1024 | 1024 | 1024 | 1024 | 1024 | 1024 | 4096 | |
| momentum | 0.995 | 0.999 | 0.999 | 0.999 | 0.999 | 0.995 | 0.999 | 0.999 | 0.999 | |
| temperature | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 0.08 | 0.05 | |
| feat_mask1 | 0.4 | 0.1 | 0.0 | 0.2 | 0.2 | 0.0 | 0.2 | 0.3 | 0.0 | |
| feat_mask2 | 0.1 | 0.2 | 0.0 | 0.5 | 0.5 | 0.2 | 0.3 | 0.3 | 0.0 | |
| drop_edge1 | 0.0 | 0 | 0.0 | 0.0 | 0.1 | 0.0 | 0.0 | 0.0 | 0.0 | |
| drop_edge2 | 0.1 | 0 | 0.0 | 0.3 | 0.2 | 0.4 | 0.2 | 0.0 | 0.0 | |
| Training | batch_size | 32 | 32 | 16 | 64 | 32 | 32 | 32 | 8 | 256 |
| epochs | 100 | 80 | 300 | 22 | 60 | 50 | 20 | 120 | 100 | |
| learning rate | 0.00015 | 0.001 | 0.001 | 0.0005 | 0.00015 | 0.00015 | 0.00015 | 0.00015 | 0.001 | |
| weight_decay | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | |
| optimizer | Adam | Adam | Adam | Adam | Adam | Adam | Adam | Adam | Adam | |
| scheduler | False | True | True | False | False | False | True | False | False |
| NCI-H23 | MOLT-4 | P388 | RDT-M12K | ||
| Model | hidden_size | 128 | 128 | 512 | 32 |
| num_layer | 3 | 3 | 2 | 5 | |
| activation | prelu | prelu | prelu | prelu | |
| norm | BN | BN | BN | BN | |
| Local | scaling factor | 2 | 2 | 2 | 2 |
| masking rate | 0.25 | 0.25 | 0.25 | 0.75 | |
| replace rate | 0.1 | 0.1 | 0.1 | 0.1 | |
| Global | 1024 | 1024 | 1024 | 1024 | |
| momentum | 0.999 | 0.999 | 0.999 | 0.999 | |
| temperature | 2 | 2 | 2 | 2 | |
| feat_mask1 | 0.0 | 0.1 | 0.2 | 0.2 | |
| feat_mask2 | 0.2 | 0.1 | 0.2 | 0.4 | |
| drop_edge1 | 0.0 | 0.0 | 0.0 | 0.0 | |
| drop_edge2 | 0.0 | 0.0 | 0.0 | 0.4 | |
| Training | batch_size | 16 | 32 | 16 | 32 |
| epochs | 100 | 100 | 100 | 100 | |
| learning rate | 0.0001 | 0.0001 | 0.0001 | 0.00015 | |
| weight_decay | 5e-4 | 5e-4 | 5e-4 | 0 | |
| optimizer | Adam | Adam | Adam | Adam | |
| scheduler | True | True | True | True |
| mask. | epochs | lr | bs | optimizer | |
| PROTEINS | 0.1 | 30 | 0.01 | 32 | Adam |
| DD | 0.1 | 50 | 0.0001 | 32 | Adam |
| NCI1 | 0.1 | 50 | 0.001 | 16 | Adam |
| MUTAG | 0.1 | 50 | 0.001 | 64 | Adam |
| NCI-H23 | 0.1 | 10 | 0.001 | 32 | Adam |
| MOLT-4 | 0.1 | 10 | 0.01 | 32 | Adam |
| P388 | 0.1 | 10 | 0.001 | 32 | Adam |
| IMDB-B | 0.1 | 20 | 0.001 | 32 | Adam |
| IMDB-M | 0.1 | 20 | 0.001 | 32 | Adam |
| | ZINC | BBBP | Tox21 | ToxCast | SIDER | ClinTox | MUV | HIV | BACE |
| # graphs | 2,000,000 | 2,039 | 7,831 | 8,576 | 1,427 | 1,477 | 93,087 | 41,127 | 1,513 |
| # binary prediction tasks | - | 1 | 12 | 617 | 27 | 2 | 17 | 1 | 1 |
| Avg. # nodes | 26.6 | 24.1 | 18.6 | 18.8 | 33.6 | 26.2 | 24.2 | 24.5 | 34.1 |
| | BBBP | Tox21 | ToxCast | SIDER | ClinTox | MUV | HIV | BACE | Avg. |
| No-pretrain | 65.5±1.8 | 74.3±0.5 | 63.3±1.5 | 57.2±0.7 | 58.2±2.8 | 71.7±2.3 | 75.4±1.5 | 70.0±2.5 | 67.0 |
| ContextPred | 64.3±2.8 | 75.7±0.7 | 63.9±0.6 | 60.9±0.6 | 65.9±3.8 | 75.8±1.7 | 77.3±1.0 | 79.6±1.2 | 70.4 |
| AttrMasking | 64.3±2.8 | 76.7±0.4 | 64.2±0.5 | 61.0±0.7 | 71.8±4.1 | 74.7±1.4 | 77.2±1.1 | 79.3±1.6 | 71.1 |
| Infomax | 68.8±0.8 | 75.3±0.5 | 62.7±0.4 | 58.4±0.8 | 69.9±3.0 | 75.3±2.5 | 76.0±0.7 | 75.9±1.6 | 70.3 |
| GraphCL | 69.7±0.7 | 73.9±0.7 | 62.4±0.6 | 60.5±0.9 | 76.0±2.7 | 69.8±2.7 | 78.5±1.2 | 75.4±1.4 | 70.8 |
| JOAO | 70.2±1.0 | 75.0±0.3 | 62.9±0.5 | 60.0±0.8 | 81.3±2.5 | 71.7±1.4 | 76.7±1.2 | 77.3±0.5 | 71.9 |
| GraphLoG | 72.5±0.8 | 75.7±0.5 | 63.5±0.7 | 61.2±1.1 | 76.7±3.3 | 76.0±1.1 | 77.8±0.8 | 83.5±1.2 | 73.4 |
| GraphMAE | 72.0±0.6 | 75.5±0.6 | 64.1±0.3 | 60.3±1.1 | 82.3±1.2 | 76.3±2.4 | 77.2±1.0 | 83.1±0.9 | 73.8 |
| SGL (Ours) | 72.6±0.4 | 76.7±0.4 | 64.3±0.2 | 62.6±0.4 | 83.3±0.9 | 79.8±1.3 | 78.7±0.4 | 84.3±0.4 | 75.3 |
| Dataset | # N | # E | # F | # C | H |
| Cora | 2708 | 5278 | 1433 | 7 | 0.81 |
| CiteSeer | 3327 | 4676 | 3703 | 6 | 0.74 |
| PubMed | 19717 | 44327 | 500 | 3 | 0.80 |
| Cora | CiteSeer | PubMed | ||
| Model | type | GAT | GAT | GAT |
| hidden_size | 512 | 512 | 1024 | |
| num_head | 4 | 2 | 4 | |
| num_layer | 2 | 2 | 2 | |
| activation | prelu | prelu | prelu | |
| norm | BN | BN | BN | |
| Local | scaling factor | 3 | 3 | 3 |
| masking rate | 0.5 | 0.5 | 0.75 | |
| replace rate | 0.05 | 0.1 | 0.0 | |
| Global | 1024 | 102400 | 1024 | |
| momentum | 0.999 | 0.999 | 0.999 | |
| temperature | 2 | 2 | 2 | |
| feat_mask1 | 0.0 | 0.1 | 0.3 | |
| feat_mask2 | 0.1 | 0.5 | 0.3 | |
| drop_edge1 | 0.1 | 0.0 | 0.2 | |
| drop_edge2 | 0.5 | 0.1 | 0.2 | |
| loss coeff. | 0.1 | 0.08 | 0.1 | |
| Training | epochs | 1500 | 300 | 1000 |
| learning rate | 0.001 | 0.001 | 0.001 | |
| weight_decay | 2e-4 | 2e-5 | 1e-5 | |
| optimizer | Adam | Adam | Adam | |
| scheduler | True | True | True |
| | Cora | CiteSeer | PubMed |
| GraphMAE | 83.77±0.62 | 73.04±0.28 | 81.05±0.30 |
| SGL | 84.15±0.35 | 73.48±0.27 | 81.31±0.41 |
| Cora | PubMed | |||
| 3-shot | 5-shot | 3-shot | 5-shot | |
| GraphMAE | 68.50±1.67 | 72.48±1.08 | 68.80±0.36 | 73.08±0.56 |
| SGL | 70.32±0.74 | 74.14±0.82 | 68.98±0.62 | 73.72±0.77 |
| GPPT | 72.14±2.64 | 74.80±1.76 | 69.52±0.80 | 72.50±0.93 |
| SGL-PT | 73.84±1.89 | 77.26±0.98 | 70.08±1.71 | 75.50±0.90 |
| MOLT-4 | P388 | |||
| 1-shot | 3-shot | 1-shot | 3-shot | |
| GraphMAE | 52.80±1.85 | 52.90±1.64 | 52.45±2.92 | 53.39±1.41 |
| SGL | 52.67±2.31 | 52.75±2.80 | 53.96±3.37 | 54.66±1.53 |
| GPPT | 50.08±0.68 | 50.77±1.23 | 50.37±2.40 | 52.52±3.77 |
| SGL-PT | 52.97±2.10 | 54.68±1.72 | 54.89±1.87 | 55.33±2.13 |
| DD | MUTAG | |||
| 1-shot | 3-shot | 1-shot | 3-shot | |
| 52.30±2.51 | 53.39±2.87 | 57.92±3.43 | 62.39±1.53 | |
| 54.11±2.09 | 58.31±2.40 | 65.46±5.02 | 69.51±2.56 | |
| 55.66±2.20 | 58.30±1.26 | 59.43±3.40 | 64.07±2.38 | |
| 57.15±1.53 | 59.49±1.25 | 68.58±5.49 | 74.47±2.32 | |
| Method | DD | MUTAG | COLLAB |
| Params(K) | Params(K) | Params(K) | |
| GIN | 820 | 10 | 296 |
| GPPT | 258 | 1 | 64 |
| GPF | 1.10 | 0.07 | 1.10 |
| ProG | 1.87 | 0.13 | 4.67 |
| GraphPrompt | 0.50 | 0.03 | 0.25 |
| SGL-PT | 1 | 0.06 | 0.75 |
Abstract
Recently, much effort has gone into designing graph self-supervised methods that yield generalized pre-trained models, and into adapting those models to downstream tasks through standard fine-tuning. However, the gap between pretext and downstream tasks can limit the potential of pre-trained models and even lead to negative transfer. Meanwhile, prompt tuning has seen emerging success in natural language processing (NLP) by aligning pre-training and fine-tuning with consistent training objectives. In this paper, we identify two challenges for graph prompt tuning. The first is the lack of a strong and universal pre-training task across the various pre-training methods in the graph domain. Such a task should be easily emulated by downstream tasks, akin to Masked Language Modeling (MLM) in the NLP domain. The second is the difficulty of designing a consistent training objective for both pre-training and downstream tasks, owing to the inherent abstraction of graph data. To overcome these obstacles, we propose a novel framework named SGL-PT, which follows the learning strategy “Pre-train, Prompt, and Predict”. Specifically, we propose a strong and universal pre-training task, coined SGL, that acquires the complementary merits of generative and contrastive self-supervised graph learning. Motivated by prompt design in NLP, we reformulate the downstream task as masked node prediction by designing a novel verbalizer-free prompting function, thereby unifying the objectives of pretext and downstream tasks. Empirical results show that our pre-training method surpasses other baselines under the unsupervised setting, and that our prompt tuning method significantly improves models on biological datasets over standard fine-tuning and other graph prompt methods.
Introduction
Graph self-supervised learning methods (Hou et al. 2022; You et al. 2020) have emerged to create generalized pre-trained models without labels. To adapt these models to downstream tasks, a “pre-train, fine-tune” approach is often used. However, there is a gap between pre-training and downstream tasks that hinders knowledge transfer and can lead to negative transfer (Zhang et al. 2022) or overfitting (Zhu et al. 2021). For instance, a model pre-trained with edge prediction (Hu et al. 2020), which captures local node relations, may struggle to generalize to graph classification tasks that require global relations. This shift in representation can lead to negative transfer (Sun et al. 2022).
A feasible way to mitigate this gap is to unify pre-training and fine-tuning under consistent training objectives through prompting (Liu et al. 2022). Although prompting is mature in the NLP domain, graph prompt tuning remains underexplored. We identify the main challenges of graph prompt tuning: (1) The need for a strong and universal pre-training task: a graph prompting method requires the pre-training task to capture rich information (e.g., intra-data and inter-data relations). Existing pre-training methods fail to learn such rich information in the graph domain, because most of them focus only on local relations (Hou et al. 2022) or only on global relations (You et al. 2020) while neglecting the interdependency of both. Furthermore, graph prompt methods require a pre-training task that downstream tasks can readily emulate, ensuring smooth integration. A pre-training task with these characteristics is currently absent. (2) The difficulty of reformulating the downstream task in the same format as the pre-training task: unlike the cloze template in the NLP domain, how to design meaningful prompt templates and verbalizers for graphs remains an open problem due to the inherent abstraction of graph data. Recently, some works have made first attempts at graph prompt tuning. GPPT (Sun et al. 2022) and GraphPrompt (Liu et al. 2023) unify the pretext and downstream tasks in a similar format (i.e., edge prediction). However, edge prediction is a trivial binary classification task that cannot capture rich inter-data information and therefore lacks generality (not solving challenge 1). GPF (Fang et al. 2022) and Prompt Graph (ProG) (Sun et al. 2023) introduce a learnable prompt feature and a learnable prompt graph, respectively, into the input space as plug-ins that can be incorporated into any pre-trained model. These methods act more like new transformation tricks applied to the input space; they do not holistically align the training objectives of pretext and downstream tasks, and hence do not fully address challenge 2. In short, the challenges we identify remain unsolved.
To overcome the above obstacles, we propose a novel graph learning framework, coined SGL-PT, that follows the “Pre-train, Prompt and Predict” strategy. Specifically, for challenge 1, we design a strong and universal graph self-supervised method named Strong Graph Learner (SGL). It combines generative and contrastive models for their complementary strengths: the generative method is more robust and can infer the characteristics of a node from its neighbors by learning local (intra-data) relations, but it lacks discriminative representations; the contrastive method learns high-quality global (inter-data) representations through instance discrimination, but may lose detailed information about individual graphs. We combine these two self-supervised methods through an asymmetric design and a dynamic queue to obtain a strong and universal self-supervised graph learner. Moreover, this pre-training task can be mimicked by the downstream task with little effort. For challenge 2, we design a novel prompting function that introduces a masked super node into each graph and reformulates downstream graph classification as masked node prediction. We realize verbalizer-free class mapping by introducing supervised graph prototypical contrastive learning, which establishes the mapping between reconstructed features and semantic labels. In this way, we unify the training objectives of pretext and downstream tasks. Results show that our pre-training method surpasses strong generative methods (Hou et al. 2022) and contrastive baselines (Xu et al. 2021a), and that our prompt tuning strategy substantially improves models on biological datasets compared to standard fine-tuning and other graph prompt methods.
Our contributions are summarized as follows:
- We identify the main challenges for graph prompting. And we propose a novel framework following “Pre-train, Prompt, and Predict” strategy to solve the challenges.
- We propose a strong and universal graph self-supervised method SGL unifying generative and contrastive merits through the asymmetric design. And we unify pre-training and fine-tuning by designing a novel verbalizer-free prompting function.
- Empirical results show that our method surpasses other baselines under the unsupervised setting, and our prompt tuning method can significantly facilitate models on biological datasets compared to fine-tuning methods.
Related Works
Graph Self-supervised Learning
Graph self-supervised methods can be classified into three categories: predictive, generative, and contrastive (Wu et al. 2021). Predictive methods generate their own labels through statistical analysis and design prediction-based pre-training tasks on those labels (Jin et al. 2020).
Generative methods focus on learning local (intra-data) relations through pretext tasks such as feature or edge reconstruction (Kipf and Welling 2016; Hou et al. 2022). Recently, generative masked autoencoders (e.g., GraphMAE (Hou et al. 2022)) have shown superior performance across different tasks. Masked autoencoders mask part of the node features, encode the graph into high-level representations, and then reconstruct the masked node features through a decoder. These methods focus on the local relations between nodes, so they may be poorly suited to learning discriminative graph representations and fall short on graph classification.
Contrastive methods maximize the agreement between positive data pairs and push negative data pairs apart in the representation space to learn the global (inter-data) relations between graphs (You et al. 2020; Zhu et al. 2022), but they may lose the detailed information within a single graph.
In our work, we combine the generative and contrastive methods through asymmetric design and a dynamic queue for complementary merits to obtain a strong and universal pre-training method.
Prompt-based Learning
The training strategy “Pre-train and Fine-tune” is widely used to adapt pre-trained models onto specific downstream tasks. However, this strategy ignores the inherent representation gap between pre-training and downstream tasks and leads to poor performance on few-shot problems.
Prompt-based learning (Liu et al. 2022) is a technique arising from NLP that narrows the gap between pre-training and downstream tasks by reformulating the downstream task to match the format of the pre-training task. It has two aspects. The first is the prompt template, which converts the task into masked word prediction. For instance, in sentiment analysis, label prediction can be transformed into masked word prediction with a pre-defined template such as “[X]. It is a [MASK] movie” for the input sentence [X]=“I love this movie.”, which aligns with the masked language model used in pre-training. The second is the verbalizer, which maps the word predicted at [MASK] to a specific label. For example, words like ‘good, great’ may be associated with the label ‘+’, while ‘bad, terrible’ are associated with the label ‘-’. The dot product between the [MASK] output and these tokens gives the confidence of assigning the sentence to each label.
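The template and verbalizer described above can be illustrated with a toy sketch. The two-dimensional embeddings, the `label_scores` helper, and the choice to average dot products over the words of each label are assumptions made for illustration only; they do not come from the paper or from any particular NLP library.

```python
# Toy illustration of a prompt template plus a verbalizer.
# All vectors below are made-up 2-d embeddings, not outputs of a real model.

TEMPLATE = "{x} It is a [MASK] movie."

# Verbalizer: label -> words that stand for that label.
VERBALIZER = {"+": ["good", "great"], "-": ["bad", "terrible"]}

# Hypothetical token embeddings.
TOKEN_EMBEDDINGS = {
    "good": [1.0, 0.0],
    "great": [0.9, 0.1],
    "bad": [-1.0, 0.0],
    "terrible": [-0.9, -0.1],
}


def dot(u, v):
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def label_scores(mask_vector, token_embeddings, verbalizer):
    """Score each label by the mean dot product between the [MASK]
    output vector and the embeddings of that label's words."""
    scores = {}
    for label, words in verbalizer.items():
        dots = [dot(mask_vector, token_embeddings[w]) for w in words]
        scores[label] = sum(dots) / len(dots)
    return scores


prompt = TEMPLATE.format(x="I love this movie.")
# Pretend the language model produced this vector at the [MASK] position.
mask_vector = [0.8, 0.2]
scores = label_scores(mask_vector, TOKEN_EMBEDDINGS, VERBALIZER)
predicted = max(scores, key=scores.get)
```

The hand-written word lists in `VERBALIZER` are exactly the part that has no obvious counterpart for graphs, which is the difficulty the following paragraphs address.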
While prompt-based learning is well established in NLP, it is relatively new in the graph domain. Existing approaches like GPPT (Sun et al. 2022) and GraphPrompt (Liu et al. 2023) rely on edge prediction as the pre-training task and reformulate the downstream task as edge prediction. But edge prediction is a trivial binary classification task that focuses on local (intra-data) relations, which makes it poorly suited to graph classification. GPF (Fang et al. 2022) introduces a universal prompt feature for various pre-trained models, and ProG (Sun et al. 2023) extends GPF with a universal graph prompt (i.e., a set of prompt features), but neither unifies the training objectives of the pre-training and downstream tasks, which limits their potential.
In contrast to prior works (Sun et al. 2022; Liu et al. 2023), we address two challenges simultaneously. By considering graph-level task characteristics and the prompt template’s dependence on the pretext task, we redesign the pre-training method and graph prompt template to mitigate representation gaps. In the following section, we will demonstrate how we overcome these obstacles.
Method
Preliminaries
This work focuses on the graph classification task over a collection of graphs. Each graph comprises a raw node feature matrix, an adjacency matrix, and a graph label. The feature matrix has one row per node and one column per raw feature dimension, and an adjacency entry is 1 if an edge exists between the corresponding pair of nodes and 0 otherwise. The model consists of online and target GNN encoders, a decoder, projection heads, and a readout function. Under the unsupervised training setting of this work, the label information is unavailable for each graph during pre-training.
SGL: Strong Graph Learner for Pre-training
In this part, we introduce our proposed pre-training strategy, depicted in Figure 1, which consists of two branches: local and global. The local branch focuses on learning intra-data relations via a graph masked autoencoder. The global branch equips the pre-trained model with instance-wise discriminative ability through graph contrastive learning. We then propose a non-trivial solution for integrating these two branches effectively. Algorithm 1 in Appendix A gives more details about the procedure.
Local Branch
In this branch, we first mask the features of a subset of nodes and obtain high-level node representations through the online encoder. We then re-mask the node representations and reconstruct the masked node features through the decoder, because GraphMAE (Hou et al. 2022) found empirically that re-masking the node representations before decoding improves performance. Finally, we use the scaled cosine error as the criterion. Our local branch loss is formulated as:
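$$\mathcal{L}_{local} = \frac{1}{|\tilde{\mathcal{V}}|} \sum_{v_i \in \tilde{\mathcal{V}}} \left(1 - \frac{\mathbf{x}_i^{\top} \mathbf{z}_i}{\|\mathbf{x}_i\| \cdot \|\mathbf{z}_i\|}\right)^{\gamma}, \quad \gamma \geq 1$$
where $\tilde{\mathcal{V}}$ represents the set of masked nodes, $\mathbf{x}_i$ denotes the original features of masked node $v_i$, and $\mathbf{z}_i$ denotes its reconstructed features. $\gamma$ is a scaling factor that adjusts the contribution of each sample. This loss is averaged over the set of masked nodes $\tilde{\mathcal{V}}$. To keep the notation uncluttered, we consider a batch size of 1 here.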
In this way, the local branch is more robust and can infer the characteristics of a node from its neighbors. However, it mainly learns local relations within individual graphs, so it cannot learn representations discriminative enough for graph classification.
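As a rough illustration of the scaled cosine error used by the local branch, the sketch below follows the form popularized by GraphMAE: one minus the cosine similarity between original and reconstructed features, raised to a scaling power and averaged over the masked nodes. The function name, the plain-list inputs, and the default `gamma=2.0` are assumptions for illustration, not the authors' implementation.

```python
import math


def scaled_cosine_error(originals, reconstructions, gamma=2.0):
    """Mean of (1 - cos(x, z)) ** gamma over the masked nodes.

    originals / reconstructions: equal-length lists of feature vectors,
    one pair per masked node. gamma >= 1 is the scaling factor.
    """
    if not originals or len(originals) != len(reconstructions):
        raise ValueError("need one reconstruction per masked node")
    total = 0.0
    for x, z in zip(originals, reconstructions):
        dot = sum(a * b for a, b in zip(x, z))
        norm_x = math.sqrt(sum(a * a for a in x))
        norm_z = math.sqrt(sum(b * b for b in z))
        cos = dot / (norm_x * norm_z + 1e-12)
        # Clamp against floating-point drift so the base stays in [0, 2].
        cos = max(-1.0, min(1.0, cos))
        total += (1.0 - cos) ** gamma
    return total / len(originals)


perfect = scaled_cosine_error([[1.0, 2.0]], [[2.0, 4.0]])      # same direction
orthogonal = scaled_cosine_error([[1.0, 0.0]], [[0.0, 1.0]])   # cos = 0
opposite = scaled_cosine_error([[1.0, 0.0]], [[-1.0, 0.0]])    # cos = -1
```

Because the error depends only on direction, a reconstruction that is a scaled copy of the original incurs almost no loss, and a larger `gamma` shrinks the contribution of nodes that are already reconstructed well.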
Global Branch
To compensate for the shortcomings of the local branch and make the encoder capture global (inter-data) discriminative information among graphs, we turn to graph contrastive learning. Although data augmentation is an essential part of contrastive learning (You et al. 2020; Shi et al. 2020; Zhu et al. 2022), we found empirically that simple augmentations such as node feature masking and edge removal are good enough to improve the representation ability of the proposed Strong Graph Learner.
The global branch proceeds as follows: first, we obtain node representations from the online and target encoders. Then, through a readout function (e.g., mean pooling), we obtain global representations as:
[TABLE]
where the readout is applied to the node representations of each graph. Following that, a projection head is added on top of the encoder to map the augmented representations to another latent space, where the contrastive loss is calculated. Finally, we contrast these representations through NT-Xent, the normalized temperature-scaled cross-entropy loss (Chen et al. 2020). A sample and its augmented view are treated as a positive pair, and all other pairs are treated as negatives. The contrastive loss is:
[TABLE]
where the loss depends on a similarity score between two projected representations, a temperature parameter, and the size of the mini-batch.
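To make the role of the similarity score, temperature, and mini-batch size concrete, here is a minimal sketch of one common NT-Xent variant, in which each online projection is contrasted against every target projection in the batch. The exact set of negatives in the paper's formula is not recoverable from the text, so the denominator used here, the cosine similarity, and the default temperature are assumptions.

```python
import math


def cosine(u, v):
    """Cosine similarity with a small epsilon for numerical safety."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-12)


def nt_xent(online, target, temperature=0.5):
    """online[i] and target[i] are projections of two views of graph i.

    For each i, the positive is target[i]; the other target[j] in the
    mini-batch act as negatives. Returns the loss averaged over the batch.
    """
    n = len(online)
    loss = 0.0
    for i in range(n):
        logits = [cosine(online[i], target[j]) / temperature for j in range(n)]
        peak = max(logits)  # subtract the max for a stable log-sum-exp
        log_denominator = peak + math.log(sum(math.exp(l - peak) for l in logits))
        loss += log_denominator - logits[i]
    return loss / n


views = [[1.0, 0.0], [0.0, 1.0]]
aligned = nt_xent(views, views)                     # positives agree
swapped = nt_xent(views, [[0.0, 1.0], [1.0, 0.0]])  # positives disagree
```

With only a handful of graphs per batch there are very few negatives in the denominator, which is the limitation the next subsection addresses.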
How to Integrate Local And Global Branches?
For efficiency and performance, small batch sizes (e.g., 8, 16) are used to train the local branch, which is insufficient for effective contrastive learning (Chen et al. 2020). Thus, direct integration of these branches is not feasible. To address this, we employ a dynamic queue that supplies more negative samples and strengthens the integration of the local and global branches. The queue is first-in, first-out, and we update it with the target representations. To keep the representations in the queue as consistent as possible, we update the target encoder and projection head with an exponential moving average. Formally, the target parameters $\xi$ are updated from the online parameters $\theta$ by $\xi \leftarrow m\xi + (1-m)\theta$, where $m$ is a momentum coefficient that controls how smoothly the target parameters evolve; we empirically use a relatively large momentum (e.g., 0.999) in our experiments. With the dynamic queue, the loss in the global branch becomes:
[TABLE]
where the loss involves the size of the dynamic queue and the representations sampled from it. The total loss of the global branch is averaged over the samples in the batch.
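The three ingredients described in this subsection, a first-in-first-out queue of target representations, an exponential-moving-average update of the target parameters, and a contrastive loss whose negatives come from the queue, can be sketched as follows. The class and function names, the flat parameter lists, and the MoCo-style form of the loss are assumptions for illustration; the paper's exact loss is not reproduced here.

```python
import math
from collections import deque


def ema_update(target_params, online_params, momentum=0.999):
    """target <- momentum * target + (1 - momentum) * online, elementwise."""
    return [momentum * t + (1.0 - momentum) * o
            for t, o in zip(target_params, online_params)]


class DynamicQueue:
    """Fixed-size FIFO store of target representations used as negatives."""

    def __init__(self, max_size):
        # A bounded deque drops its oldest entries once it is full.
        self._items = deque(maxlen=max_size)

    def enqueue(self, representations):
        self._items.extend(list(r) for r in representations)

    def negatives(self):
        return [list(r) for r in self._items]

    def __len__(self):
        return len(self._items)


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-12)


def queue_contrastive_loss(query, positive, negatives, temperature=0.5):
    """Negative log-softmax of the positive pair against queued negatives."""
    logits = [cosine(query, positive) / temperature]
    logits += [cosine(query, n) / temperature for n in negatives]
    peak = max(logits)
    log_denominator = peak + math.log(sum(math.exp(l - peak) for l in logits))
    return log_denominator - logits[0]


queue = DynamicQueue(max_size=3)
queue.enqueue([[1.0], [2.0]])
queue.enqueue([[3.0], [4.0]])  # [1.0] is evicted, the queue stays at size 3
updated = ema_update([1.0], [0.0], momentum=0.9)
```

Decoupling the number of negatives from the batch size is what lets the contrastive branch share the small batches preferred by the reconstruction branch, and a momentum close to 1 keeps older queue entries comparable with freshly computed ones.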
In summary, the overall pre-training loss is defined as
[TABLE]
where a weighting coefficient controls the weight of the local loss. In most cases we set it to 0.5, which means the local and global losses contribute equally.
Verbalizer-free Graph Prompt Tuning
After obtaining the pre-trained model from SGL, we propose a novel graph prompt tuning technique that mitigates the representation gap between the pre-trained model and downstream tasks. Through prompt addition, we reformulate the downstream task in the same format as the pre-training task. With the design of a verbalizer-free prompt answer, we dispense with the verbalizer, which is hard to design in the graph domain. The process of graph prompt tuning is shown in Figure 2, and Algorithm 2 in Appendix A gives the procedure.
Prompt Addition: Re-formulate downstream task
Prompt-based learning methods mitigate the representation gap by reformulating the downstream task in the same format as the pre-training task. However, doing so is non-trivial in the graph domain due to the inherent abstraction of graph data. To solve this problem, we introduce a masked super node, which connects to all nodes in the graph and therefore has a global receptive field. The representation of the super node can thus be seen as the representation of the whole graph. By reconstructing the features of the masked super node, we transform the original classification task into masked feature reconstruction, which corresponds to the pre-training task. This idea is motivated by prompt tuning in the NLP domain, which adds a template with a task slot (a masked word) to the input sentence and predicts the masked word in that slot (Liu et al. 2022), and it differs from previous graph prompt works (Sun et al. 2022).
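A minimal sketch of the prompt addition step, assuming a dense adjacency matrix stored as nested lists and a placeholder vector standing in for the mask token. The helper name `add_masked_super_node` and the zero-valued mask token are hypothetical; how the mask token is parameterized in the authors' code is not stated in the text.

```python
def add_masked_super_node(adjacency, features, mask_token):
    """Append one super node that is connected to every existing node.

    adjacency: n x n list of 0/1 rows (undirected graph).
    features:  n feature vectors.
    mask_token: feature vector used for the masked super node.
    Returns the new adjacency, the new features, and the super node index.
    """
    num_nodes = len(adjacency)
    # Every original node gains an edge to the super node.
    new_adjacency = [list(row) + [1] for row in adjacency]
    # The super node links back to every original node, with no self-loop.
    new_adjacency.append([1] * num_nodes + [0])
    # The super node carries only the mask token as its input feature.
    new_features = [list(row) for row in features] + [list(mask_token)]
    return new_adjacency, new_features, num_nodes


# A 3-node path graph: 0 - 1 - 2.
adjacency = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
prompted_adj, prompted_feats, super_index = add_masked_super_node(
    adjacency, features, mask_token=[0.0, 0.0]
)
```

After one round of message passing the super node has received information from every node, so the feature reconstructed at `super_index` can play the role that the [MASK] slot plays in an NLP template.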
Prompt Answer: Verbalizer-free Class Mapping
We have reformulated the downstream task as masked node feature reconstruction. Since the masked super nodes have no semantic or representative input features related to labels, it is hard to use a verbalizer to establish mappings between reconstructed features and their semantic labels.
In this work, we get rid of the verbalizer by introducing supervised prototypical contrastive learning (SPCL) (Li et al. 2020; Cui et al. 2022) for class mapping. Specifically, the prototypes represent the essential features corresponding to labels. As depicted in Figure 2, we obtain class prototypes through SPCL and use these prototypes as semantic tokens related to labels. The supervised prototypical contrastive loss is given by:
[TABLE]
where the first part is the instance-instance loss, which draws intra-class pairs together and pushes inter-class pairs apart in the representation space. The second part is the instance-prototype loss, which makes the similarity between an instance and the prototype of its own class larger than its similarity to the other prototypes. Class prototypes are learnable vectors, which are updated by the second part of the loss. Both parts use NT-Xent, the same loss as in the global branch.
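The two parts of the loss can be sketched as follows. This is a generic supervised contrastive formulation written under assumptions: the temperature, the averaging over positives, and the function names are illustrative and are not taken from the paper's equation.

```python
import math


def cosine(u, v):
    """Cosine similarity between two plain-Python vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + 1e-12)


def spcl_loss(z, y, prototypes, tau=0.5):
    """Return (instance-instance loss, instance-prototype loss).

    z          : list of instance representations
    y          : list of integer class labels, one per instance
    prototypes : list of class prototype vectors, indexed by label
    """
    n = len(z)
    inst_inst, inst_proto = 0.0, 0.0
    for i in range(n):
        others = [j for j in range(n) if j != i]
        positives = [j for j in others if y[j] == y[i]]
        if positives:
            denom = sum(math.exp(cosine(z[i], z[j]) / tau) for j in others)
            # Same-class instances are positives, all others negatives.
            inst_inst += -sum(
                math.log(math.exp(cosine(z[i], z[j]) / tau) / denom)
                for j in positives
            ) / len(positives)
        # The prototype of the instance's own class is the positive.
        proto_denom = sum(math.exp(cosine(z[i], c) / tau) for c in prototypes)
        own = math.exp(cosine(z[i], prototypes[y[i]]) / tau)
        inst_proto += -math.log(own / proto_denom)
    return inst_inst / n, inst_proto / n
```

In a trainable version the prototypes would be parameters that receive gradients only through the second term, which matches the statement that they are updated by the instance-prototype part.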
Prompt Tuning
Through prompt engineering (i.e., prompt addition and prompt answer), we reformulate the downstream task in the same format as the pre-training task: the original classification task is transformed into reconstructing the super node’s representation. To keep the training objective consistent with pre-training and avoid catastrophic forgetting of the pre-trained knowledge, we use the reconstruction loss as an auxiliary loss with a low masking rate during prompt tuning. The overall loss of prompt tuning is defined as:
[TABLE]
where a coefficient controls the contribution of each component loss. We set it to 0.1 for all datasets to focus on learning the class prototypes used for classification.
In this way, we implement verbalizer-free graph prompt tuning. This loss has a format similar to that of the pre-training loss.
Prediction
As in prompt tuning, we add a masked super node to the original graph and aim to reconstruct its corresponding class prototype. Specifically, by comparing the representation of the super node with each class prototype, we obtain the predicted class probability:
[TABLE]
We choose the class with the highest probability as the prediction:
[TABLE]
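The prediction step can be sketched as a softmax over the similarities between the super node representation and the class prototypes, followed by an argmax. Cosine similarity and the temperature value are assumptions here, chosen to stay consistent with an NT-Xent-style training loss.

```python
import math


def cosine(u, v):
    """Cosine similarity between two plain-Python vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + 1e-12)


def predict(super_node_z, prototypes, tau=0.5):
    """Return (class probabilities, index of the most likely class)."""
    scores = [math.exp(cosine(super_node_z, c) / tau) for c in prototypes]
    total = sum(scores)
    probs = [s / total for s in scores]
    best = max(range(len(probs)), key=probs.__getitem__)
    return probs, best
```

No separate classifier head or verbalizer is involved: the prototypes learned during prompt tuning are the only class-specific parameters used at inference.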
Experiments
In this section, we first introduce the datasets and experimental setups used for graph classification. We then evaluate the proposed self-supervised pre-training method SGL. Next, we demonstrate the effectiveness of our prompt tuning framework SGL-PT against standard fine-tuning and other graph prompt methods under semi-supervised and few-shot settings. After that, we conduct an ablation study to show the necessity of each design choice. Appendix C additionally contains a sensitivity analysis of crucial hyperparameters (e.g., dynamic queue size and loss coefficient), supplementary experiments that substantiate the robustness and generality of our pre-training method (e.g., molecule property prediction with over 2 million graphs, and node classification), on which it reaches SOTA performance, and an empirical validation of our prompt design on few-shot node-level tasks. Finally, Appendix D compares parameter efficiency across prompt methods to show SGL-PT’s efficiency.
Datasets
We perform experiments on graph-level tasks using 12 widely used datasets from TUDataset (Morris et al. 2020). The statistics of these datasets can be found in Table 1. They fall into two categories: biological and social networks. More details of these datasets are in Appendix B.
Evaluation of Proposed Pre-trained Method
In this section, we evaluate our pre-training method SGL under the unsupervised setting.
Baselines
Our baselines mainly consist of three categories: supervised methods, graph kernel methods, and other unsupervised methods. Specifically, we compare with three supervised baselines: GCN (Kipf and Welling 2017), GIN (Xu et al. 2019) and DiffPool (Ying et al. 2018). The SOTA graph kernel methods include graphlet kernel (GL) (Shervashidze et al. 2009), Weisfeiler-Lehman sub-tree kernel (WL) (Shervashidze et al. 2011) and deep graph kernel (DGK) (Yanardag and Vishwanathan 2015). We also compare with unsupervised representation learning methods including sub2vec (Adhikari et al. 2018), graph2vec (Narayanan et al. 2017), EdgePred (Hu et al. 2020), InfoGraph (Sun et al. 2019), GraphCL (You et al. 2020), JOAO (You et al. 2021), SimGRACE (Xia et al. 2022), MVGRL (Hassani and Khasahmadi 2020), InfoGCL (Xu et al. 2021a) and GraphMAE (Hou et al. 2022). The introduction of baselines can be found in Appendix B.
Experiment Setup
The quality of the pre-trained graph encoder is then evaluated by the linear separability of the final representations. Namely, an additional trainable linear classifier is built on top of the frozen encoder following (Hou et al. 2022; You et al. 2020). We adopt GIN (Xu et al. 2019) as our encoder and decoder with the default setting in (Hou et al. 2022). The complete hyper-parameters and more details are listed in Appendix B.
Analysis
The results are listed in Table 2 and Table 3, from which we can draw the following conclusions:
SGL outperforms kernel methods on all datasets by a large margin (e.g., 3.2% absolute improvement over the SOTA graph kernel method DGK on PROTEINS). Our method even outperforms the best supervised model (i.e., GIN) on eight out of twelve datasets, which closes the gap between unsupervised and supervised methods.
Compared with other unsupervised methods, SGL achieves SOTA results except on MUTAG. We suspect this dataset is too small to realize the full potential of our pre-training method. On the larger-scale datasets in Table 3, our method surpasses other methods by considerable margins (e.g., around 1.3% absolute improvement over other methods on the NCI-H23 dataset). These strong results show the superiority of our proposed pre-training method SGL.
Evaluation of Proposed Prompt-tuning
In this section, we investigate the effectiveness of our proposed prompt method SGL-PT by conducting experiments under different settings (i.e., semi-supervised and few-shot). Our prompt method can be applied to other pre-training methods with minor modifications, which are described in Appendix C.
Semi-supervised Setting
Experiments follow a semi-supervised setting with pre-training and fine-tuning (Chen et al. 2020; You et al. 2020). We do not freeze the pre-trained model; all parameters are tuned for downstream tasks. In more resource-limited scenarios, such as the few-shot setting in the next section, we freeze the pre-trained model and train only the additional parameters for downstream tasks (e.g., the classifier).
Baselines
To thoroughly investigate the effectiveness of the proposed SGL-PT, we compare it with methods of different training strategies.
First, all the baselines (except ‘No pre-train.’) are pre-trained with pretext tasks. For fine-tuning methods, we obtain the pre-trained models in advance and fine-tune them with a linear classifier on downstream labeled data.
For GPPT, EdgePred serves as the pre-training method, with downstream tasks reformulated into edge prediction tasks following their prompt design. GraphPrompt, akin to GPPT, employs EdgePred for pre-training and adapts downstream tasks into edge prediction, using a simplified prompt template that features a weighted-summation readout. A further baseline employs SGL as the pre-trained model and adds a learnable graph prompt feature to the node attributes. Another baseline also uses SGL as the pre-trained model and incorporates a learnable prompt graph into the original graph. In SGL-PT, SGL serves as the pre-trained model, and we reframe downstream tasks in a format similar to the pre-training task. Furthermore, to demonstrate that the superiority of our prompt method does not depend only on a superior pre-trained model, we replace EdgePred with SGL in GPPT and GraphPrompt, which yields an SGL-based variant of each. We use grid search on important hyper-parameters to get the best performance.
Experimental Setup
We ensure a fair comparison by using the same model configuration for all methods. The detailed settings and hyper-parameters are in Appendix B.
Analysis
Table 4 summarizes the results of different training strategies, and we can get the following information:
For fine-tuning-based methods, performance follows the order “EdgePred < GraphCL < GraphMAE < SGL”, which is consistent with the earlier unsupervised results. SGL surpasses the other fine-tuning-based methods, which again confirms the effectiveness of our pre-training method. Notably, EdgePred fails to outperform ‘No pre-train.’ on some datasets, which indicates that EdgePred triggers negative transfer.
SGL-PT outperforms all fine-tuning-based methods and surpasses SGL by a large margin (around 3% average improvement over all datasets), which indicates that a representation gap exists between pre-training and downstream tasks and that minimizing it matters. SGL-PT also outperforms GPPT and GraphPrompt by a large margin (around 6% average improvement over all datasets). GPPT and GraphPrompt perform poorly because they exploit only a limited part of the knowledge learned by the pre-trained model. SGL-PT further surpasses SGL-based prompt baselines by a considerable margin (around 5% average improvement over all datasets), which supports the effectiveness of unifying the training objectives.
On some datasets, the four SGL-based prompt baselines even perform worse than SGL because of catastrophic forgetting caused by inconsistent objectives (e.g., on the DD dataset they lag behind by 2.8%, 0.8%, 6.2% and 1.8%, respectively). This experiment shows that the effectiveness of our prompt tuning method does not depend on a superior pre-trained model; what matters is establishing training objectives that are consistent across pre-training and downstream tasks.
Few-shot Setting
In many real-world scenarios, it is challenging to collect and label a large amount of data. Few-shot learning is a well-known case of low-resource scenarios. We conduct experiments in such a setting to prove the effectiveness of our method in low-resource scenarios. Experiments on more datasets can be found in Appendix C.
Experimental Setup
In this section, we evaluate different training strategies with more limited supervision in a few-shot setting. Only a small number K of labeled graphs is available per class, which is referred to as K-shot classification. We perform experiments with 1-shot and 3-shot graph classification to evaluate all methods. The model setup remains consistent with the previous section. Additionally, for prompt methods, we freeze the pre-trained model and train only the supplementary parameters for downstream tasks.
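A few-shot training set of this kind can be drawn as in the sketch below, which samples a fixed number of labeled graphs per class and treats the remaining graphs as the held-out pool. The function name, the seeding scheme, and the choice to return all remaining indices as one pool are assumptions, not details from the paper.

```python
import random


def sample_k_shot(labels, k, seed=0):
    """Pick k graph indices per class; return (train_idx, remaining_idx)."""
    rng = random.Random(seed)
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)
    train = []
    for label in sorted(by_class):
        members = by_class[label]
        if len(members) < k:
            raise ValueError(f"class {label} has fewer than {k} graphs")
        train.extend(rng.sample(members, k))
    chosen = set(train)
    remaining = [i for i in range(len(labels)) if i not in chosen]
    return sorted(train), remaining
```

For example, `sample_k_shot(labels, 1)` and `sample_k_shot(labels, 3)` produce the 1-shot and 3-shot training sets; repeating the draw with different seeds gives independent few-shot episodes.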
Analysis
Table 5 summarizes the results of different training strategies, and we can obtain the following results:
In the few-shot setting, prompt methods achieve better performance than standard fine-tuning-based methods (e.g., GPPT outperforms EdgePred, and SGL-PT outperforms SGL), which shows the importance of unifying the pre-training and downstream tasks.
SGL-PT still outperforms the other methods, which shows the effectiveness of our prompt design even in low-resource scenarios (e.g., SGL-PT surpasses other methods by over 7% absolute improvement on 5-shot MUTAG).
Even with a stronger pre-trained model, the SGL-based prompt baselines still perform only moderately on most datasets, highlighting the importance of unifying the training objectives of pre-training and downstream tasks to fully exploit learned knowledge. These findings underscore the need to solve the challenges we identified simultaneously.
Ablation Study
To verify the design of our pre-training method SGL, we conduct ablation experiments that remove different components under the same model configuration. SGL contains two branches (i.e., local and global), and we remove them separately. ‘w/o global branch’ means that we use only the local branch. ‘w/o local branch’ means that we use only the contrastive learning method. ‘w/o dynamic queue’ means that we do not use the dynamic queue to provide adequate negative samples. In Table 6, the results of ‘w/o dynamic queue’ lag far behind SGL, which demonstrates that the dynamic queue is essential for integrating the two branches. The single branches on their own also lag behind SGL, which means that, through efficient combination, we acquire the complementary strengths of both generative (local branch) and contrastive (global branch) methods.
Conclusion
In this work, we identify the main challenges for graph prompt tuning. To solve them, we propose a “Pre-train, prompt and predict” framework coined as SGL-PT. This framework unifies the pre-training and downstream task in the same format and minimizes the training objective gap. Specifically, we design a strong and universal pre-training task that acquires the complementary strengths of generative and contrastive methods. Based on this pre-training method, we design a novel verbalizer-free prompting function to reformulate the downstream task in the same format as our pre-training method. Empirical results show that our pre-training method surpasses other baselines under the unsupervised setting, and our prompt tuning method can greatly facilitate pre-trained models compared to standard fine-tuning methods and other graph prompt tuning techniques.
A. Algorithm
Algorithm of SGL
The overall process of SGL is described in Algorithm 1. The input graphs are fed into two branches (i.e., local and global). In the local branch, we mask partial node features and pass the masked graph through the online encoder to obtain fragmentary representations. Following (Hou et al. 2022), we re-mask the node representations, and a decoder then restores the input features on the masked nodes. The reconstruction loss is computed by Equation 1. Meanwhile, in the global branch, we augment the input graphs and feed them into the target encoder, readout function and target projection head to obtain global representations. We reuse the node representations from the local branch to obtain the representations of the other view through the readout function and the online projection head. Finally, Equation 5 gives the global loss of the global branch. The total loss is a weighted sum of the local and global losses. The online encoder, projection head and decoder are updated by gradient descent, whereas the target encoder and projection head are updated by momentum update.
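The masking step at the start of the local branch can be sketched as below. The function name, the fixed mask token, and the rule of masking at least one node are assumptions for illustration; the encoder, re-masking and decoder stages are not shown.

```python
import random


def mask_node_features(features, mask_rate, mask_token, seed=0):
    """Replace the features of a random subset of nodes with a mask token.

    Returns (masked_features, masked_indices). The input list is not modified.
    """
    rng = random.Random(seed)
    n = len(features)
    # At least one node is masked so that a reconstruction target exists.
    num_mask = max(1, int(round(mask_rate * n)))
    masked = sorted(rng.sample(range(n), num_mask))
    out = [list(f) for f in features]
    for i in masked:
        out[i] = list(mask_token)
    return out, masked
```

The reconstruction loss would then be computed only on the returned `masked_indices`, comparing the decoder output with the original features at those positions.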
Algorithm of Graph Prompt Tuning
The overall process of graph prompt tuning in SGL-PT is summarized in Algorithm 2. First, we add a masked super node to each downstream graph and reformulate the downstream task as masked node prediction. We use prototypical contrastive learning to get rid of the verbalizer, and we additionally mask partial node features and compute a reconstruction loss to avoid catastrophic forgetting. Lastly, the prototypical contrastive loss and the reconstruction loss are linearly combined through a weighting coefficient.
B. Experimental Details
Datasets
In this subsection, we give a more detailed description of the datasets used in the main paper. More descriptions can be found in (Yan et al. 2008; Yanardag and Vishwanathan 2015).
For bioinformatics datasets:
- MUTAG: This dataset consists of 188 mutagenic aromatic and heteroaromatic nitro compounds. Each compound is represented as a graph, with nodes corresponding to atoms and edges indicating bonds between them. There are 7 discrete labels associated with the mutagenic activity of the compounds.
- PROTEINS: In this dataset, nodes represent secondary structure elements (SSEs) in protein structures. An edge is established between two nodes if the corresponding SSEs are neighboring in the amino-acid sequence or in 3D space. The dataset contains protein structures with 3 discrete labels, representing helix, sheet, or turn.
- DD: The dataset comprises 1178 protein structures, each depicted as a graph wherein amino acids serve as nodes, and edges connect two nodes if their spatial separation is within 6 Angstroms. The objective of this predictive endeavor is the categorization of the protein structures as either enzymatic or non-enzymatic entities. Importantly, it’s worth noting that nodes are uniformly labeled across all datasets.
- NCI1: Derived from the National Cancer Institute (NCI), this dataset is a subset of chemoinformatics datasets. It contains chemical compounds that have been screened for their ability to suppress or inhibit the growth of a panel of human tumor cell lines. NCI1 has 37 discrete labels, representing different outcomes of the cell line growth inhibition assay.
- NCI-H23, MOLT-4 and P388: They offer insights into the biological activities of small molecules, including bioassay records for anticancer screenings across various cancer cell lines, specifically categorized as ‘Non-Small Cell Lung’, ‘Leukemia’ and ‘Leukemia’ respectively.
These datasets offer a diverse range of chemical and biological contexts, providing challenges for graph classification tasks with varying numbers of labels and graph structures.
For social network datasets:
- IMDB-BINARY and IMDB-MULTI: These datasets are based on movie collaboration networks. Each graph represents the ego-network of an actor or actress, where nodes correspond to individuals and an edge is present between two nodes if they have collaborated in the same movie. IMDB-BINARY is a binary classification task where the goal is to classify each graph into a specific movie genre. IMDB-MULTI is a multi-class classification task with the objective of classifying the graphs into various movie genres.
- REDDIT-BINARY and REDDIT-M12K: These datasets capture online discussion threads. Nodes in these graphs represent users participating in discussions, and edges are formed when users interact by responding to each other’s comments. REDDIT-BINARY involves binary classification, aiming to classify graphs into specific community or subreddit labels. REDDIT-M12K extends this to multi-class classification over 11 different subreddits, namely “AskReddit, AdviceAnimals, atheism, aww, IAmA, mildlyinteresting, Showerthoughts, videos, todayilearned, worldnews, TrollXChromosomes”. The goal is to predict which subreddit a given discussion graph belongs to.
- COLLAB: Derived from scientific collaboration data, this dataset consists of ego-networks representing researchers in different fields. Each graph focuses on researchers from a specific scientific field, such as High Energy Physics, Condensed Matter Physics, or Astrophysics. The classification task involves assigning each graph to the corresponding scientific field.
These datasets provide diverse scenarios for evaluating graph classification algorithms, with each graph representing a unique context of collaboration or interaction, and the classification tasks focusing on predicting genres, communities, or scientific fields.
Note that social networks do not contain raw node attributes, so we use node degrees as their attributes following (Hou et al. 2022). For biological graphs, we use their categorical node attributes.
Baselines
In this part, we will introduce the baselines used in our experiments:
- Edge prediction (EdgePred) (Hu et al. 2020) treats existing links as training signals. Its training objective is binary cross-entropy loss.
- Unsupervised And Semi-Supervised Graph-level Representation Learning via Mutual Information Maximization (InfoGraph) (Sun et al. 2019) takes a pair of global representation and patch representation as input and employs a discriminator to determine if they belong to the same graph based on Deep InfoMax (Hjelm et al. 2018). This process generates all possible positive and negative samples in a batch-wise manner.
- Graph Contrastive Learning with Augmentations (GraphCL) (You et al. 2020) proposes various augmentation techniques for graph data and investigates their impacts on different types of datasets. First, the input graph is fed into two random augmentation functions to generate two graph views. These augmented graphs are then fed into a GNN encoder with a readout function to obtain graph representations. Finally, these graph representations are contrasted with the InfoNCE loss (Oord, Li, and Vinyals 2018).
- Graph Contrastive Learning Automated (JOAO) (You et al. 2021) introduces the concept of joint augmentation optimization, which formulates a bi-level optimization problem by simultaneously optimizing the selection of augmentations and the contrastive objective.
- A Simple Framework for Graph Contrastive Learning without Data Augmentation (SimGRACE) (Xia et al. 2022) eliminates data augmentation while introducing encoder perturbations to generate distinct views for graph contrastive learning.
- Contrastive Multi-View Representation Learning on Graphs (MVGRL) (Hassani and Khasahmadi 2020) utilizes the information of multiple views for contrasting. First, it uses an edge diffusion function to generate an augmented graph. Asymmetric encoders are then applied to the original graph and the diffusion graph to acquire node embeddings. Next, a readout function is employed to derive graph-level representations. Original node representations and the augmented graph-level representation are regarded as positive pairs. Additionally, the augmented node representations and the original graph-level representation are also considered positive pairs. The negative pairs are constructed through random shuffling.
- Information-Aware Graph Contrastive Learning (InfoGCL) (Xu et al. 2021a) suggests minimizing the mutual information between contrastive parts while preserving task-relevant information within both the individual module and the overall framework. This approach aims to minimize information loss during graph representation learning, following the Information Bottleneck principle (Alemi et al. 2016).
- Self-Supervised Masked Graph Autoencoders (GraphMAE) (Hou et al. 2022) is a masked autoencoder. It first masks partial input node attributes, then the encoder compresses the masked graph into latent space, and finally a decoder aims to reconstruct the masked attributes.
Evaluation Protocol
Unsupervised Representation Learning
For small-scale datasets, we follow (Hou et al. 2022; You et al. 2020) to assess our SGL pre-training method. After pre-training, we keep the model fixed to generate graph-level representations. These representations are then fed into a downstream LIB-SVM (Chang and Lin 2011) classifier on small-scale datasets. It’s important to note that all self-supervised methods are trained using unsupervised data. The evaluation of the pre-trained model involves training only a classifier with supervised data from the downstream task. Reported results include the mean accuracy from 10-fold cross-validation, with standard deviation after 5 runs, on small-scale datasets.
For larger-scale datasets (NCI-H23, MOLT-4, P388, REDDIT-M12K), due to LIB-SVM’s convergence issues, we adopt a one-layer MLP as the downstream classifier. Results are reported as the mean performance from 5-fold cross-validation, using accuracy for REDDIT-M12K and ROC-AUC (Davis and Goadrich 2006) for others.
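The cross-validation protocol can be sketched with the helpers below, which generate fold indices and summarize per-fold scores as mean and standard deviation. The splitting here is a plain shuffled split without stratification, and the function names are hypothetical; whether the original protocol stratifies by class is not stated in the text.

```python
import random
import statistics


def k_fold_indices(n, k=10, seed=0):
    """Yield (train_idx, test_idx) pairs for k-fold cross-validation."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    # Striding assigns every index to exactly one fold.
    folds = [idx[i::k] for i in range(k)]
    for i in range(k):
        test = folds[i]
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train, test


def summarize(scores):
    """Return (mean, sample standard deviation) of a list of scores."""
    return statistics.mean(scores), statistics.stdev(scores)
```

A downstream classifier (an SVM for the small datasets, a one-layer MLP for the larger ones) would be fitted on the frozen representations indexed by `train_idx` and scored on `test_idx` in each fold.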
Semi-supervised Setting
To assess our proposed prompt method, we carry out experiments in both semi-supervised and few-shot settings. In the semi-supervised setup, we fix the label rate at 10%, indicating that only 10% of the training data contains labels. We perform parameter tuning for both encoders and additional downstream components (e.g., classifier) to optimize performance on downstream tasks.
Few-shot Setting
In contrast to the previous section, we introduce even scarcer supervised signals in this scenario to simulate low-resource conditions. Specifically, each class comprises only one or three instances for training data, known as 1-shot and 3-shot graph classification. To address concerns of overfitting and parameter efficiency, we solely fine-tune the downstream parameters in the few-shot setting.
Hyper-parameters
All hyper-parameters used in unsupervised learning are listed in Table 7. The coefficient in Equation 5 is 0.5 for most datasets, except COLLAB and NCI1 (0.9). For the local branch, we adopt hyper-parameters similar to those in (Hou et al. 2022). For the global branch, we do not use augmentation on relatively large datasets (e.g., NCI1 and ZINC). For the other small datasets, we use grid search to obtain the augmentation ratios. For the training hyper-parameters (e.g., batch size, epochs, etc.), we adopt settings similar to those in (Hou et al. 2022).
The hyper-parameters for prompt tuning are detailed in Table 9. During prompt tuning, we still mask partial node features at a low mask rate (10%) to avoid catastrophic forgetting of the pre-trained knowledge and over-fitting. For all methods, we use the same hyper-parameters for a fair comparison. Increasing the number of training epochs may further boost downstream performance. All the pre-trained models are fine-tuned under the same setting. Some hyper-parameters are selected by grid search: the learning rate is searched in {0.01, 0.001, 0.0001} and the readout function in {mean, max, sum}. We use the same batch size for all methods.
Computer Infrastructures Specifications
For hardware, most experiments are conducted on a computer server with four GeForce RTX 2080Ti GPUs with 11GB memory and 48 Intel(R) Xeon(R) CPU E5-2678 v3 @ 2.50GHz. Our models are implemented with PyTorch Geometric 2.0.4 (Fey and Lenssen 2019), DGL 0.9.1 (Wang et al. 2019) and PyTorch 1.11.0 (Paszke et al. 2019). All the datasets used in our work are available in the DGL and PyTorch Geometric libraries. For molecular property prediction, our implementation is based on the code in https://github.com/snap-stanford/pretrain-gnns with PyTorch Geometric 2.0.4 on a GeForce RTX 3090.
C. Additional Experiments
Due to space constraints in the main content, this section presents additional experiments on molecule property prediction and node classification, which aim to demonstrate the robustness and versatility of our method. We also analyze two important hyperparameters, the dynamic queue size and the loss coefficient. Lastly, we evaluate our graph prompting method on more datasets and apply it to other pre-training methods.
Molecule Property Prediction
Besides the single-label graph classification task in the main content, we also evaluate our pre-training method on another graph-level task (i.e., molecular property prediction) to predict chemical molecule properties. In this experiment, a larger molecule dataset is employed to pre-train the model, followed by fine-tuning on smaller downstream datasets. This experiment serves to assess the ability of our pre-training method to generalize across different distributions, showcasing its potential for transfer learning.
Datasets
The ZINC dataset, which consists of 2 million unlabeled molecules sampled from ZINC15 (Sterling and Irwin 2015), is used for pre-training. The other eight datasets are downstream datasets from MoleculeNet (Wu et al. 2018); input node features are atom number and chirality tag, and edge features are bond type and direction. Scaffold-split is used to split graphs into train/val/test sets in a way that mimics real-world use cases. The statistics of these datasets can be found in Table 10. These datasets are widely used for evaluating the transferability of pre-training methods (You et al. 2020; Sun et al. 2019; You et al. 2021).
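The idea behind a scaffold split can be sketched as follows: molecules that share a scaffold are kept in the same partition, and larger scaffold groups are assigned first. The sketch assumes the scaffold of each molecule has already been computed as a string key (for example by a cheminformatics toolkit), and the 80/10/10 fractions are an assumption rather than a value stated in the text.

```python
def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Split molecule indices so that no scaffold spans two partitions.

    scaffolds : list of precomputed scaffold keys, one per molecule
    Returns (train_idx, valid_idx, test_idx).
    """
    groups = {}
    for idx, key in enumerate(scaffolds):
        groups.setdefault(key, []).append(idx)
    # Largest groups first; ties broken by first index for determinism.
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))
    n = len(scaffolds)
    train_cut = frac_train * n
    valid_cut = (frac_train + frac_valid) * n
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= train_cut:
            train.extend(group)
        elif len(train) + len(valid) + len(group) <= valid_cut:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test
```

Because rare scaffolds tend to end up in the validation and test partitions, the evaluation measures generalization to structurally different molecules rather than to near-duplicates of the training set.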
Baselines
We employ several baselines, including GIN without pre-training (i.e., fine-tuned directly on the downstream dataset without self-supervised pre-training), as well as GraphCL (You et al. 2020), JOAO (You et al. 2021), GraphLoG (Xu et al. 2021b), GraphMAE (Hou et al. 2022) and three different pre-training strategies (i.e., Infomax, AttrMasking and ContextPred) proposed in (Hu et al. 2020), which heuristically incorporate domain knowledge that correlates with the specific downstream datasets. All baselines (except No-pretrain) are first pre-trained on the ZINC dataset and then adapted to the downstream datasets.
Experimental Setup
We evaluate SGL under the transfer learning setting following (Hou et al. 2022; You et al. 2020). First, we pre-train the model with 2 million unlabeled molecules sampled from ZINC15 (Sterling and Irwin 2015), and then we fine-tune the pre-trained models on 8 multi-label multi-class benchmark datasets contained in MoleculeNet (Wu et al. 2018) with scaffold-split. In the local branch, we reconstruct node features and do not reconstruct edge features, as in (Hou et al. 2022). In the global branch, we set the dynamic queue size, temperature and momentum to 4096, 0.05 and 0.999, respectively. For simplicity, we do not use additional augmentation here. For evaluation, we run experiments 10 times and report the mean and standard deviation of ROC-AUC scores (%) on the 8 downstream datasets following (Hou et al. 2022; You et al. 2020).
Analysis
As Table 11 shows, SGL surpasses other methods and reaches SOTA performance on eight datasets. It achieves around 8% improvement over No-pretrain GIN, which demonstrates the strength of our pre-training method SGL. It also improves substantially on GraphMAE on many datasets (e.g., 3.5% and 2.3% gains on the MUV and SIDER datasets), which further shows the effectiveness of combining contrastive and generative methods.
Node Classification
Datasets
We choose three commonly used citation datasets, i.e., Cora, CiteSeer and PubMed (Sen et al. 2008), in this part. These datasets consist of nodes representing different papers and edges signifying citation relationships between them. Each node is associated with a bag-of-words representation of the corresponding paper, and the label corresponds to the academic topic of the paper. Further details about the dataset statistics can be found in Table 12.
Experimental Setup
We follow the same experimental setup in GraphMAE (Hou et al. 2022). After finishing unsupervised pre-training, we will freeze the pre-trained model and exclusively train a linear classifier using labeled training data for downstream tasks. The specific hyperparameters are listed in Table 13. And we report the mean accuracy with standard deviation after 20 runs. For simplicity, we only compare with GraphMAE here because it is a strong baseline.
Furthermore, in this section, we assess the performance of our prompt method in a few-shot setting. (To achieve prompt tuning on node classification, we extract an ego graph for each target node and assign ego-graph labels based on the corresponding node labels; the subsequent steps mirror the graph-level process.) The model configuration remains consistent with the above. For standard fine-tuning approaches, all parameters are fine-tuned using downstream data. In contrast, for prompt methods, the pre-trained models are kept frozen and only the additional parameters for the downstream task are tuned.
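The conversion from node classification to graph classification can be sketched as a breadth-first extraction of the neighborhood around each target node. The adjacency-dictionary representation and the default of two hops are assumptions; the text does not state the ego-graph radius.

```python
from collections import deque


def extract_ego_graph(adj, center, num_hops=2):
    """Return (nodes, edges) of the num_hops-hop ego graph around center.

    adj : dict mapping each node to an iterable of its neighbors
    """
    dist = {center: 0}
    frontier = deque([center])
    while frontier:
        u = frontier.popleft()
        if dist[u] == num_hops:
            continue  # do not expand beyond the hop limit
        for v in adj.get(u, ()):
            if v not in dist:
                dist[v] = dist[u] + 1
                frontier.append(v)
    nodes = sorted(dist)
    # Keep only the edges whose endpoints both lie inside the ego graph.
    edges = [(u, v) for u in nodes for v in adj.get(u, ()) if v in dist]
    return nodes, edges
```

Each extracted ego graph inherits the label of its center node, after which the masked super node and prototype comparison proceed exactly as in the graph-level case.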
Analysis
Table 14 shows the results of unsupervised representation learning and illustrates the consistent superiority of our SGL pre-training method across these three datasets, which proves the effectiveness of combining contrastive and generative methods.
Moreover, our prompt tuning approach for node classification also demonstrates its effectiveness, as seen in Table 15. Notably, our prompt method achieves over 3% absolute improvement compared to standard fine-tuning approaches and over 2.4% absolute improvement compared to another graph prompt method, GPPT, on the 5-shot Cora dataset. These results show the superiority of our prompt design even on node-level tasks.
Evaluating SGL-PT on Other Graph-level Datasets
Due to space constraints, only partial results of prompt tuning are presented in the main content. In this section, additional experiments are exhibited in Table 16. The conclusions align with those stated in the main paper.
Applying SGL-PT on Other Pre-trained Methods
Our prompt method can also be applied to other contrastive methods. Our prompt loss is a weighted combination of a contrastive term and a generative term, so setting the coefficient to 0 reduces it to the contrastive loss, which aligns the objectives of the pretext and downstream tasks. To validate this assertion, we substitute the pre-trained model with the GraphCL model. The outcomes are detailed in Table 17. Our approach consistently outperforms standard fine-tuning by a significant margin. For instance, on the 3-shot DD dataset, SGL-PT surpasses GraphCL by more than 6% in absolute terms, and on the 3-shot MUTAG dataset, it shows over 12% absolute improvement. Furthermore, our prompt method achieves superior results compared to other graph prompt methods, underscoring the efficacy of our prompt design. This success also validates the adaptability of our prompt approach to other contrastive methods.
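As a sketch of the argument above, the role of the coefficient can be written out under assumed notation. The symbols $\mathcal{L}_{\text{con}}$ (contrastive term), $\mathcal{L}_{\text{gen}}$ (generative term), and $\alpha$ (coefficient) are placeholders chosen here for illustration and are not taken from the source:

```latex
\mathcal{L}_{\text{prompt}} = \mathcal{L}_{\text{con}} + \alpha \, \mathcal{L}_{\text{gen}},
\qquad
\alpha = 0 \;\Longrightarrow\; \mathcal{L}_{\text{prompt}} = \mathcal{L}_{\text{con}}
```

With the coefficient at zero, the downstream objective has the same form as the pretext objective of a purely contrastive model such as GraphCL.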
Sensitivity Analysis
Analysis on the Dynamic Queue Size
In the ablation study, we find that the dynamic queue is essential for better integrating the contrastive and generative methods. In this section, we examine the impact of the dynamic queue size. The results are shown in Figure 3: performance improves as the queue size increases and is best when the size reaches 512 or 1024. This size is much larger than the batch size (e.g., 8 or 16), which shows the necessity of the dynamic queue for integrating the two branches.
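To illustrate why a queue can hold far more samples than a single mini-batch, here is a minimal sketch of a fixed-capacity first-in, first-out buffer. It assumes, in the style of momentum-contrast methods, that the queue stores embeddings from recent batches; the source does not spell out what the queue holds or how it is updated, so the class name, its methods, and the toy embeddings are all hypothetical.

```python
from collections import deque


class DynamicQueue:
    """FIFO buffer of embeddings that persists across mini-batches."""

    def __init__(self, max_size):
        # deque(maxlen=...) silently evicts the oldest items when full.
        self.buffer = deque(maxlen=max_size)

    def enqueue(self, batch_embeddings):
        self.buffer.extend(batch_embeddings)

    def candidates(self):
        """Return every stored embedding, oldest first."""
        return list(self.buffer)

    def __len__(self):
        return len(self.buffer)


queue = DynamicQueue(max_size=512)
batch_size = 16
for step in range(40):
    # Toy two-dimensional "embeddings" tagged with their step and position.
    batch = [[float(step), float(i)] for i in range(batch_size)]
    queue.enqueue(batch)

stored = queue.candidates()
```

After 40 steps of 16 embeddings each, 640 items have been enqueued, so the queue holds the most recent 512 and has discarded the 128 oldest, giving each step access to 32 batches' worth of samples.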
Analysis on the Loss Coefficient
During pre-training, the model captures both local and global information, so we set the loss coefficient to 0.5. During prompt tuning, we focus on learning the class prototypes used for classification, so we lower this coefficient to 0.1. Empirically, we found that downstream performance drops as the coefficient increases (e.g., on DD). We therefore set it to 0.1 for all datasets.
D. Parameter Efficiency
We compare the trainable parameters of GIN, GPPT, GPF, ProG, GraphPrompt, and SGL-PT in Table 18. The ‘GIN’ entry represents the number of parameters required to train a GIN model from scratch, which is significantly more than for the other methods. For prompt methods, we freeze the pre-trained models and count only the downstream task-specific parameters. GPPT requires task tokens (used for classification) and a structure token involving an attention module (used for neighbor aggregation) to accomplish downstream tasks, resulting in significantly more trainable parameters than other prompt methods. GPF needs to initialize a learnable vector as well as a downstream classifier. ProG needs to initialize a learnable prompt graph (i.e., a set of learnable vectors). GraphPrompt introduces a learnable weighted-summation readout function but needs to execute clustering to obtain class prototypes. SGL-PT only requires training class prototypes for classification.
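A minimal sketch of prototype-based classification shows why training only class prototypes keeps the parameter count small. Cosine similarity, the function names, and the toy dimensions are assumptions for illustration; the source does not specify the similarity measure in this section.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def predict(embedding, prototypes):
    """Return the index of the prototype most similar to a frozen embedding."""
    scores = [cosine_similarity(embedding, p) for p in prototypes]
    return max(range(len(prototypes)), key=scores.__getitem__)


def prototype_parameter_count(num_classes, hidden_dim):
    """One learnable vector of size hidden_dim per class."""
    return num_classes * hidden_dim


# Two hypothetical class prototypes in a 3-dimensional embedding space.
prototypes = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
predicted = predict([0.2, 0.9, 0.1], prototypes)
num_trainable = prototype_parameter_count(num_classes=2, hidden_dim=3)
```

Under these assumptions, the trainable parameters grow only with the number of classes times the embedding dimension, independent of the encoder's size.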
GraphPrompt, while having the fewest trainable parameters, does require additional computational time due to its clustering process for obtaining class prototypes. This step becomes increasingly time-consuming as the downstream data expands. For example, in semi-supervised settings, executing GraphPrompt on NCI-H23 takes more than a day, while other methods complete the task in just two hours.
To conclude, while GraphPrompt has fewer trainable parameters than SGL-PT, SGL-PT surpasses GraphPrompt in both efficiency and effectiveness.
References
- Adhikari, B.; Zhang, Y.; Ramakrishnan, N.; and Prakash, B. A. 2018. Sub2vec: Feature learning for subgraphs. In Pacific-Asia Conference on Knowledge Discovery and Data Mining, 170–182. Springer.
- Alemi, A. A.; Fischer, I.; Dillon, J. V.; and Murphy, K. 2016. Deep variational information bottleneck. arXiv preprint arXiv:1612.00410.
- Chang, C.-C.; and Lin, C.-J. 2011. LIBSVM: A Library for Support Vector Machines. ACM Trans. Intell. Syst. Technol., 2(3).
- Chen, T.; Kornblith, S.; Norouzi, M.; and Hinton, G. 2020. A simple framework for contrastive learning of visual representations. In International conference on machine learning, 1597–1607. PMLR.
- Cui, G.; Hu, S.; Ding, N.; Huang, L.; and Liu, Z. 2022. Prototypical Verbalizer for Prompt-based Few-shot Tuning. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), 7014–7024.
- Davis, J.; and Goadrich, M. 2006. The relationship between Precision-Recall and ROC curves. In Proceedings of the 23rd international conference on Machine learning, 233–240.
- Fang, T.; Zhang, Y.; Yang, Y.; and Wang, C. 2022. Prompt tuning for graph neural networks. arXiv preprint arXiv:2209.15240.
- Fey, M.; and Lenssen, J. E. 2019. Fast graph representation learning with PyTorch Geometric. arXiv preprint.
