diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..ef1f3072 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +# ignore __pycache__ folder +__pycache__/ \ No newline at end of file diff --git a/tokenizer.py b/tokenizer.py index 8a64c905..c2befc26 100644 --- a/tokenizer.py +++ b/tokenizer.py @@ -18,16 +18,23 @@ def __init__(self, tokenizer_model_path): self.pad_token_id = 0 # self.tokenizer.pad_id() self.newline_token_id = 13 - # Encode string - def encode(self, text, return_mask = False, max_seq_len = 2048): + def encode(self, text, return_mask = False, max_seq_len = 2048, add_bos = False, add_eos = False): if isinstance(text, list): # text is a list of strings list_ids = self.tokenizer.EncodeAsIds(text) + + # optionally prepend BOS / append EOS tokens + + if add_bos: + for ids in list_ids: ids.insert(0, self.bos_token_id) + if add_eos: + for ids in list_ids: ids.append(self.eos_token_id) + max_length = max([len(ids) for ids in list_ids]) needs_mask = False @@ -56,6 +63,14 @@ def encode(self, text, return_mask = False, max_seq_len = 2048): # text is a single string ids = self.tokenizer.EncodeAsIds(text) + + # optionally prepend BOS / append EOS tokens + + if add_bos: + ids = [self.bos_token_id] + ids + if add_eos: + ids = ids + [self.eos_token_id] + stacked_ids = torch.tensor(ids).unsqueeze(0) if return_mask: