@@ -137,10 +137,13 @@ def __init__(
137137 self .token_on_disk = token_on_disk
138138 self .tokenize_batch_size = tokenize_batch_size
139139 self ._token = self .tokenize_graph (self .tokenize_batch_size )
140- self ._llm_explanation_token = self .tokenize_graph (
141- self .tokenize_batch_size , text_type = 'llm_explanation' )
142- self ._all_token = self .tokenize_graph (self .tokenize_batch_size ,
143- text_type = 'all' )
140+ self ._llm_explanation_token : Dict [str , Tensor ] = {}
141+ self ._all_token : Dict [str , Tensor ] = {}
142+ if self .name in self .llm_explanation_id :
143+ self ._llm_explanation_token = self .tokenize_graph (
144+ self .tokenize_batch_size , text_type = 'llm_explanation' )
145+ self ._all_token = self .tokenize_graph (self .tokenize_batch_size ,
146+ text_type = 'all' )
144147 self .__num_classes__ = dataset .num_classes
145148
146149 @property
@@ -170,14 +173,16 @@ def token(self) -> Dict[str, Tensor]:
170173
171174 @property
172175 def llm_explanation_token (self ) -> Dict [str , Tensor ]:
173- if self ._llm_explanation_token is None : # lazy load
176+ if self ._llm_explanation_token is None and \
177+ self .name in self .llm_explanation_id :
174178 self ._llm_explanation_token = self .tokenize_graph (
175179 text_type = 'llm_explanation' )
176180 return self ._llm_explanation_token
177181
178182 @property
179183 def all_token (self ) -> Dict [str , Tensor ]:
180- if self ._all_token is None : # lazy load
184+ if self ._all_token is None and \
185+ self .name in self .llm_explanation_id :
181186 self ._all_token = self .tokenize_graph (text_type = 'all' )
182187 return self ._all_token
183188
@@ -230,13 +235,15 @@ def download(self) -> None:
230235 filename = 'node-text.csv.gz' ,
231236 log = True )
232237 self .text = list (read_csv (raw_text_path )['text' ])
233- print ('downloading llm explanations' )
234- llm_explanation_path = download_google_url (
235- id = self .llm_explanation_id [self .name ], folder = f'{ self .root } /raw' ,
236- filename = 'node-gpt-response.csv.gz' , log = True )
237- self .llm_explanation = list (read_csv (llm_explanation_path )['text' ])
238- print ('downloading llm predictions' )
239- fs .cp (f'{ self .llm_prediction_url } /{ self .name } .csv' , self .raw_dir )
238+ if self .name in self .llm_explanation_id :
239+ print ('downloading llm explanations' )
240+ llm_explanation_path = download_google_url (
241+ id = self .llm_explanation_id [self .name ],
242+ folder = f'{ self .root } /raw' , filename = 'node-gpt-response.csv.gz' ,
243+ log = True )
244+ self .llm_explanation = list (read_csv (llm_explanation_path )['text' ])
245+ print ('downloading llm predictions' )
246+ fs .cp (f'{ self .llm_prediction_url } /{ self .name } .csv' , self .raw_dir )
240247
241248 def process (self ) -> None :
242249 # process Title and Abstraction
@@ -276,20 +283,21 @@ def process(self) -> None:
276283 for i , pred in enumerate (preds ):
277284 pl [i ][:len (pred )] = torch .tensor (
278285 pred [:self .llm_prediction_topk ], dtype = torch .long ) + 1
286+
287+ if self .llm_explanation is None or pl is None :
288+ raise ValueError (
289+ "The TAGDataset only has ogbn-arxiv LLM explanations "
290+ "and predictions by default. The LLM explanation and "
291+ "prediction of each node is not specified. Please pass in "
292+ "'llm_explanation' and 'llm_prediction' when "
293+ "converting your dataset to a Text Attribute Graph Dataset" )
279294 elif self .name in self .llm_explanation_id :
280295 self .download ()
281296 else :
282297 print (
283298 'The dataset is not ogbn-arxiv,'
284299 'please pass in your llm explanation list to `llm_explanation`'
285300 'and llm prediction list to `llm_prediction`' )
286- if self .llm_explanation is None or pl is None :
287- raise ValueError (
288- "The TAGDataset only have ogbn-arxiv LLM explanations"
289- "and predictions in default. The llm explanation and"
290- "prediction of each node is not specified."
291- "Please pass in 'llm_explanation' and 'llm_prediction' when"
292- "convert your dataset to Text Attribute Graph Dataset" )
293301
294302 def save_node_text (self , text : List [str ]) -> None :
295303 node_text_path = osp .join (self .root , 'raw' , 'node-text.csv.gz' )
0 commit comments