Description
The _tokenize_batch method in dataloader.py fails when processing the batch[col_name] field, because its type is pyarrow.ListArray. The .to() function is called on this field, which results in an error because pyarrow.ListArray does not support the .to() method. (I may be missing something in the implementation; I'm using the env setup provided in the repo.)
Expected Behavior
The batch[col_name] field should be properly processed as a torch.Tensor or converted to one before calling .to().
Observed Behavior
The code throws an error when attempting to execute .to(self.gang.device) because pyarrow.ListArray does not have a .to() method.
Code Snippet
embs = [x.to(self.gang.device).to(dtype) for x in batch[col_name]]
Potential Solution
The batch[col_name] field may need to be explicitly converted to a torch.Tensor before calling .to(). For example:
embs = [torch.Tensor(x.as_py()).to(self.gang.device).to(dtype) for x in batch[col_name]]