Skip to content

Add utility func to get path to assets #777

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions test/common/assets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
import shutil
import atexit
import tempfile
from pathlib import Path

# Resolved absolute path of the "asset" directory located two levels above
# this file (i.e. a sibling of this file's parent directory).
_ASSET_DIR = (Path(__file__).parent.parent / "asset").resolve()

# Holds the shared `tempfile.TemporaryDirectory` object; remains None until
# `_init_temp_dir()` assigns it.
_TEMP_DIR = None


def _init_temp_dir():
    """Initialize the shared temporary directory and register its clean-up.

    The directory is stored in the module-level ``_TEMP_DIR`` and removed via
    an ``atexit`` hook when the interpreter exits. The function is idempotent:
    if the directory has already been created, the call is a no-op, so at most
    one directory and one exit hook are ever created.
    """
    global _TEMP_DIR
    if _TEMP_DIR is not None:
        # Already initialized; reuse the existing directory so paths handed
        # out earlier keep sharing one location.
        return
    # Deliberately not used as a context manager: the directory must outlive
    # this call and is cleaned up by the atexit hook below.
    _TEMP_DIR = tempfile.TemporaryDirectory()  # noqa
    atexit.register(_TEMP_DIR.cleanup)


def get_asset_path(*path_components, use_temp_dir=False):
    """Get the path to a file under the `test/asset` directory.

    Args:
        *path_components: Path segments, relative to the asset directory,
            that identify the asset (e.g. ``"sub", "data.csv"``).
        use_temp_dir: When True, the asset is copied to a temporary location
            and the path to the temporary copy is returned.

    Returns:
        str: Path to the asset, or to its temporary copy.

    Raises:
        ValueError: If ``use_temp_dir`` is True and no path component is given.
    """
    # Design note (from code review): a reviewer suggested exposing this as a
    # context manager (``with get_asset(...)``) that cleans up the temporary
    # file when it goes out of scope. That is the typical approach, but it
    # would force every caller to be aware of temporary files/directories.
    # Here they are abstracted away and cleaned up at interpreter exit.
    path = str(_ASSET_DIR.joinpath(*path_components))
    if not use_temp_dir:
        return path

    if not path_components:
        raise ValueError(
            "At least one path component is required when use_temp_dir=True.")

    if _TEMP_DIR is None:
        _init_temp_dir()

    # Mirror the asset's location relative to the asset directory so that
    # same-named files from different sub-directories do not overwrite each
    # other. Fall back to the bare file name if the asset lies outside the
    # asset directory, so the copy never escapes the temporary directory.
    rel = os.path.relpath(path, str(_ASSET_DIR))
    if rel == os.pardir or rel.startswith(os.pardir + os.sep) or os.path.isabs(rel):
        rel = os.path.basename(path)
    tgt = os.path.join(_TEMP_DIR.name, rel)
    os.makedirs(os.path.dirname(tgt), exist_ok=True)
    shutil.copy(path, tgt)
    return tgt
14 changes: 10 additions & 4 deletions test/data/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@
)

from ..common.torchtext_test_case import TorchtextTestCase
from ..common.assets import get_asset_path


class TestFunctional(TorchtextTestCase):
def test_generate_sp_model(self):
# Test the function to train a sentencepiece tokenizer

data_path = 'test/asset/text_normalization_ag_news_test.csv'
# buck (FB-internal) generates a test environment whose path contains ','.
# SentencePieceTrainer treats such a path as a comma-delimited file list,
# so as a workaround we copy the asset data to a temporary directory and load it from there.
data_path = get_asset_path(
'text_normalization_ag_news_test.csv',
use_temp_dir=True)
generate_sp_model(data_path,
vocab_size=23456,
model_prefix='spm_user')
Expand All @@ -38,7 +44,7 @@ def test_generate_sp_model(self):

def test_sentencepiece_numericalizer(self):
test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
model_path = 'test/asset/spm_example.model'
model_path = get_asset_path('spm_example.model')
sp_model = load_sp_model(model_path)
self.assertEqual(sp_model.GetPieceSize(), 20000)
spm_generator = sentencepiece_numericalizer(sp_model)
Expand All @@ -52,7 +58,7 @@ def test_sentencepiece_numericalizer(self):
def test_sentencepiece_tokenizer(self):

test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
model_path = 'test/asset/spm_example.model'
model_path = get_asset_path('spm_example.model')
sp_model = load_sp_model(model_path)
self.assertEqual(sp_model.GetPieceSize(), 20000)
spm_generator = sentencepiece_tokenizer(sp_model)
Expand Down Expand Up @@ -99,7 +105,7 @@ def encode_as_pieces(self, input: str):

class TestScriptableSP(unittest.TestCase):
def setUp(self):
model_path = 'test/asset/spm_example.model'
model_path = get_asset_path('spm_example.model')
with tempfile.NamedTemporaryFile() as file:
torch.jit.script(ScriptableSP(model_path)).save(file.name)
self.model = torch.jit.load(file.name)
Expand Down
16 changes: 9 additions & 7 deletions test/data/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import io
import unittest

import torchtext.data as data
import pytest
from ..common.torchtext_test_case import TorchtextTestCase
from torchtext.utils import unicode_csv_reader
import io

from ..common.torchtext_test_case import TorchtextTestCase
from ..common.assets import get_asset_path


class TestUtils(TorchtextTestCase):
Expand All @@ -21,8 +24,7 @@ def test_get_tokenizer_spacy(self):

# TODO: Remove this once the issue has been resolved.
# TODO: Add nltk data back in build_tools/travis/install.sh.
@pytest.mark.skip(reason=("Impractically slow! "
"https://github.com/alvations/sacremoses/issues/61"))
@unittest.skip("Impractically slow! https://github.com/alvations/sacremoses/issues/61")
def test_get_tokenizer_moses(self):
# Test Moses option.
# Note that internally, MosesTokenizer converts to unicode if applicable
Expand Down Expand Up @@ -54,13 +56,13 @@ def test_text_nomalize_function(self):
test_lines = []

tokenizer = data.get_tokenizer("basic_english")
data_path = 'test/asset/text_normalization_ag_news_test.csv'
data_path = get_asset_path('text_normalization_ag_news_test.csv')
with io.open(data_path, encoding="utf8") as f:
reader = unicode_csv_reader(f)
for row in reader:
test_lines.append(tokenizer(' , '.join(row)))

data_path = 'test/asset/text_normalization_ag_news_ref_results.test'
data_path = get_asset_path('text_normalization_ag_news_ref_results.test')
with io.open(data_path, encoding="utf8") as ref_data:
for line in ref_data:
line = line.split()
Expand Down