Skip to content

Commit 8831694

Browse files
authored
Update TensorFlow checkpoint loading for compatibility with tensorflow v2.X
1 parent 44dee49 commit 8831694

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torch_geometric/nn/models/dimenet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ def from_qm9_pretrained(
593593
download_url(f'{url}/ckpt.index', path)
594594

595595
path = osp.join(path, 'ckpt')
596-
reader = tf.train.load_checkpoint(path)
596+
reader = tf.compat.v1.train.load_checkpoint(path)
597597

598598
model = cls(
599599
hidden_channels=128,
@@ -865,7 +865,7 @@ def from_qm9_pretrained(
865865
download_url(f'{url}/ckpt.index', path)
866866

867867
path = osp.join(path, 'ckpt')
868-
reader = tf.train.load_checkpoint(path)
868+
reader = tf.compat.v1.train.load_checkpoint(path)
869869

870870
# Configuration from DimeNet++:
871871
# https://github.com/gasteigerjo/dimenet/blob/master/config_pp.yaml

0 commit comments

Comments
 (0)