Conversation

@xnuohz
Contributor

@xnuohz xnuohz commented Oct 24, 2024

Issue

Feature Summary

  • Add GitMolDataset
  • Add GITMol as a GNN & LLM co-training model to PyG
  • Add an example for pre-training
  • Due to limited hardware resources, the full training pipeline was not tested
  • Multi-modal cross-attention shares the same weights across modalities, which is not aligned with the original paper
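To make the last point concrete, here is a minimal sketch (hypothetical names, not the PyG or GIT-Mol implementation) of the difference between sharing one cross-attention projection across all modality branches, as in this PR, and giving each modality its own projection, as in the original paper:

```python
# Hypothetical sketch: shared vs. per-modality cross-attention
# projections. Names are illustrative, not the PyG implementation.

class CrossAttentionProjection:
    """Stands in for a learnable query/key/value projection module."""
    def __init__(self, name):
        self.name = name

def build_shared(modalities):
    # One projection object reused by every modality branch:
    # every branch reads and updates the same parameters.
    shared = CrossAttentionProjection("shared_qkv")
    return {m: shared for m in modalities}

def build_per_modality(modalities):
    # A separate projection per modality, as in the original paper.
    return {m: CrossAttentionProjection(f"{m}_qkv") for m in modalities}

modalities = ["graph", "image", "smiles"]
shared = build_shared(modalities)
separate = build_per_modality(modalities)

# Shared setup: all branches point at the identical parameter object.
assert shared["graph"] is shared["image"]
# Per-modality setup: each branch owns its own parameters.
assert separate["graph"] is not separate["image"]
```

The trade-off is parameter count versus modality-specific capacity; the shared variant is smaller but cannot specialize the attention projections per modality.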

@xnuohz xnuohz changed the title [WIP] GIT-Mol Add GIT-Mol Nov 3, 2024
@xnuohz
Contributor Author

xnuohz commented Nov 15, 2024

@puririshi98 Local CI passed.

@puririshi98
Contributor

@xnuohz I am confused: how is the data meant to be downloaded if the `download` function is a `pass` and the `process` function does not include any download step?

root@keystone-dvt1d-023-114:/workspace/pytorch_geometric# python3 examples/llm/git_mol.py 
Processing...
Traceback (most recent call last):
  File "/workspace/pytorch_geometric/examples/llm/git_mol.py", line 127, in <module>
    train(
  File "/workspace/pytorch_geometric/examples/llm/git_mol.py", line 41, in train
    train_dataset = GitMolDataset(path, split=0)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch_geometric/datasets/git_mol_dataset.py", line 71, in __init__
    super().__init__(root, transform, pre_transform, pre_filter,
  File "/usr/local/lib/python3.12/dist-packages/torch_geometric/data/in_memory_dataset.py", line 81, in __init__
    super().__init__(root, transform, pre_transform, pre_filter, log,
  File "/usr/local/lib/python3.12/dist-packages/torch_geometric/data/dataset.py", line 115, in __init__
    self._process()
  File "/usr/local/lib/python3.12/dist-packages/torch_geometric/data/dataset.py", line 262, in _process
    self.process()
  File "/usr/local/lib/python3.12/dist-packages/torch_geometric/datasets/git_mol_dataset.py", line 161, in process
    data = pd.read_pickle(
           ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pandas/io/pickle.py", line 185, in read_pickle
    with get_handle(
         ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pandas/io/common.py", line 882, in get_handle
    handle = open(handle, ioargs.mode)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: '/workspace/pytorch_geometric/data/GITMol/raw/igcdata_toy/train_3500.pkl'

I know you are low on compute, so feel free to ping me on Slack when you have a change and I can test it quickly to make sure we have something working.
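For context on the traceback: PyG's `Dataset` base class calls `download()` before `process()` whenever the raw files are missing, so if `download()` is a no-op, `process()` runs against an empty `raw_dir` and fails exactly as above. A minimal stand-in (illustrative only, not the real `torch_geometric` code) showing the hook order:

```python
import os
import tempfile

# Minimal stand-in for torch_geometric.data.Dataset's hook order:
# download() runs first when raw files are missing, then process().
# Illustrative only; not the real torch_geometric implementation.
class MiniDataset:
    def __init__(self, root):
        self.raw_dir = os.path.join(root, "raw")
        os.makedirs(self.raw_dir, exist_ok=True)
        if not os.listdir(self.raw_dir):  # no raw files yet
            self.download()               # expected to fetch them
        self.process()                    # parse raw -> processed

    def download(self):
        pass  # no-op: nothing is fetched, mirroring the reviewed code

    def process(self):
        # Fails with FileNotFoundError when download() left raw_dir empty.
        with open(os.path.join(self.raw_dir, "train_3500.pkl"), "rb"):
            pass

class FixedDataset(MiniDataset):
    def download(self):
        # A real PyG dataset would call download_url(...) / extract_zip(...)
        # here; a stub file stands in for the extracted archive.
        with open(os.path.join(self.raw_dir, "train_3500.pkl"), "wb") as f:
            f.write(b"stub")

root = tempfile.mkdtemp()
broken_failed = False
try:
    MiniDataset(os.path.join(root, "broken"))
except FileNotFoundError:
    broken_failed = True   # process() ran against an empty raw_dir

FixedDataset(os.path.join(root, "fixed"))  # download() fills raw_dir first
```

This is why implementing `download()` (as the merged version does, fetching and extracting `gitmol.zip`) resolves the `FileNotFoundError`.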

@puririshi98 puririshi98 self-requested a review November 22, 2024 17:37
Contributor

@puririshi98 puririshi98 left a comment

LGTM
in NVIDIA container:

/workspace/pytorch_geometric# python3 examples/llm/git_mol.py 
Downloading https://drive.usercontent.google.com/download?id=1loBXabD6ncAFY-vanRsVtRUSFkEtBweg&confirm=t
Extracting /workspace/pytorch_geometric/data/GITMol/raw/gitmol.zip
Processing...
  0%|                                                  | 0/3610 [00:00<?, ?it/s]
/usr/local/lib/python3.12/dist-packages/torchvision/transforms/_functional_pil.py:113: RuntimeWarning: invalid value encountered in cast
  np_h += np.array(hue_factor * 255).astype(np.uint8)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3610/3610 [00:08<00:00, 421.48it/s]
Done!
Using existing file gitmol.zip
Extracting /workspace/pytorch_geometric/data/GITMol/raw/gitmol.zip
Processing...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:00<00:00, 721.40it/s]
Done!
Using existing file gitmol.zip
Extracting /workspace/pytorch_geometric/data/GITMol/raw/gitmol.zip
Processing...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:00<00:00, 786.91it/s]
Done!
config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 385/385 [00:00<00:00, 3.86MB/s]
vocab.txt: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 228k/228k [00:00<00:00, 3.47MB/s]
pytorch_model.bin: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 442M/442M [00:06<00:00, 64.4MB/s]
config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 71.8k/71.8k [00:00<00:00, 65.6MB/s]
Some weights of BertModel were not initialized from the model checkpoint at allenai/scibert_scivocab_uncased and are newly initialized: ['bert.encoder.layer.{0-11}.crossattention.{self.query, self.key, self.value, output.dense, output.LayerNorm}.{weight, bias}'] (the cross-attention weights and biases for all 12 encoder layers)
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training beginning...
Epoch: 1|3:   0%|                                       | 0/902 [00:00<?, ?it/s]
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Epoch: 1|3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 902/902 [04:31<00:00,  3.32it/s]
Epoch: 1|3, Train loss: 0.886526, Val loss: 1.076906
Epoch: 2|3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 902/902 [04:31<00:00,  3.32it/s]
Epoch: 2|3, Train loss: 0.797038, Val loss: 1.134945
Epoch: 3|3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 902/902 [04:34<00:00,  3.29it/s]
Epoch: 3|3, Train loss: 0.776674, Val loss: 1.252366
/usr/local/lib/python3.12/dist-packages/torch/cuda/memory.py:369: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Test loss: 1.255684
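One observation on the log above: train loss falls each epoch while validation loss rises (1.077 → 1.135 → 1.252), a classic overfitting signature for a short toy run. A generic early-stopping sketch (not part of this PR) that would select the epoch-1 checkpoint from these numbers:

```python
# Generic early-stopping sketch over per-epoch validation losses;
# illustrative only, not code from the PR or the example script.
def best_epoch(val_losses, patience=2):
    """Return (1-based best epoch, its val loss); stop scanning after
    `patience` consecutive epochs without improvement."""
    best_i, best = 0, float("inf")
    bad = 0
    for i, v in enumerate(val_losses):
        if v < best:
            best_i, best = i, v
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break
    return best_i + 1, best

# Validation losses from the run above (epochs 1-3).
val_losses = [1.076906, 1.134945, 1.252366]
epoch, loss = best_epoch(val_losses)  # epoch 1 has the lowest val loss
```

In a real training loop one would save the model state at each new best epoch and restore it before the final test evaluation.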

@puririshi98 puririshi98 merged commit f732427 into pyg-team:master Nov 25, 2024
16 checks passed
@xnuohz xnuohz deleted the git-mol branch November 25, 2024 05:49
mattjhayes3 pushed a commit to mattjhayes3/pytorch_geometric that referenced this pull request Dec 14, 2024
### Issue
- pyg-team#9694
- pyg-team#9700

### Feature Summary

- Add `GitMolDataset`
- Add `GITMol` as a GNN & LLM co-training model to PyG
- Add an example for pre-training
- Due to limited hardware resources, the full training pipeline was not tested
- Multi-modal cross-attention shares the same weights across modalities, which is not aligned with the original paper

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rishi Puri <[email protected]>