Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions src/hyrax/verbs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def run(self):

from hyrax.config_utils import create_results_dir, log_runtime_config
from hyrax.gpu_monitor import GpuMonitor
from hyrax.model_exporters import export_to_onnx
from hyrax.pytorch_ignite import (
create_trainer,
create_validator,
Expand Down Expand Up @@ -124,19 +123,6 @@ def run(self):
logger.info("Finished Training")
tensorboardx_logger.close()

context = {
"ml_framework": "pytorch",
"results_dir": results_dir,
}

# Get a sample of input data. If the data is labeled, only return the input data.
batch_sample = next(iter(train_data_loader))
if isinstance(batch_sample, dict):
batch_sample = model.to_tensor(batch_sample)
sample = batch_sample[0] if isinstance(batch_sample, (list, tuple)) else batch_sample

export_to_onnx(model, sample, config, context)

return model

@staticmethod
Expand Down
Loading