# ----------------------------------------------------------------------------
#  SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its
#  affiliates <open-source-office@arm.com>
#  SPDX-License-Identifier: Apache-2.0
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
# ----------------------------------------------------------------------------

import typing
import urllib.parse
from pathlib import Path

import executorch.extension.pybindings.portable_lib  # noqa
import executorch.kernels.quantized  # noqa
import requests
import torch
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.extension.export_util.utils import save_pte_program
from executorch.runtime import Runtime

# Directory containing this script; both the download and the generated
# .pte are placed here.
current_dir = Path(__file__).parent
# Destination path for the generated ExecuTorch program.
conformer_pte_path = current_dir / "conformer.pte"
# Pre-quantized Conformer exported program (.pt2) hosted on Hugging Face.
conformer_model_url = "https://huggingface.co/Arm/stt_en_conformer_executorch_small/resolve/main/Conformer_ArmQuantizer_quant_exported_program.pt2"


def download_file(url: str, output_path: Path, timeout: float = 60.0) -> Path:
    """Download ``url`` to ``output_path`` unless the file already exists.

    If ``output_path`` is a directory, the file name is derived from the
    URL's path component and the file is placed inside that directory.

    Args:
        url: HTTP(S) URL to fetch.
        output_path: Destination file, or a directory to place the file in.
        timeout: Seconds to wait for the server (connect/read) before
            aborting. Without this, a stalled server would hang forever.

    Returns:
        The path of the (possibly pre-existing) downloaded file.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    if output_path.is_dir():
        parsed_url = urllib.parse.urlparse(url)
        output_path = output_path / Path(parsed_url.path).name
    if not output_path.exists():
        # Stream into a temporary sibling file and rename only on success,
        # so an interrupted download is not mistaken for a complete file
        # by the exists() check on the next run.
        tmp_path = output_path.with_suffix(output_path.suffix + ".part")
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on POSIX: either the old state or the complete file.
        tmp_path.replace(output_path)
    return output_path


def generate_pte(exported_program_path: Path, output_path: Path):
    """Lower a quantized exported program (.pt2) to an ExecuTorch .pte file.

    Does nothing when ``output_path`` already exists, so repeated runs
    reuse the previously generated program.

    Args:
        exported_program_path: Path to the ``torch.export``-saved program.
        output_path: Where to write the resulting .pte file.
    """
    if output_path.exists():
        return
    exported_program = torch.export.load(exported_program_path)
    # NOTE(review): IR validity checking is disabled here — presumably the
    # quantized program would otherwise fail the default verifier; confirm.
    edge_config = EdgeCompileConfig(_check_ir_validity=False)
    edge_manager = to_edge_transform_and_lower(
        exported_program,
        partitioner=None,
        compile_config=edge_config,
    )
    backend_config = ExecutorchBackendConfig(extract_delegate_segments=False)
    executorch_manager = edge_manager.to_executorch(config=backend_config)
    save_pte_program(executorch_manager, str(output_path))


def run_inference(pte_path: Path) -> typing.Sequence[torch.Tensor]:
    """Load the .pte program and execute its "forward" method once.

    Feeds a random tensor of shape (1, 1500, 80) — assumed to be the
    model's expected (batch, frames, features) input; TODO confirm
    against the model card.

    Args:
        pte_path: Path to the ExecuTorch program to run.

    Returns:
        The sequence of output tensors produced by the method.
    """
    dummy_input = torch.randn(1, 1500, 80)
    program = Runtime.get().load_program(pte_path)
    forward = program.load_method("forward")
    return forward.execute([dummy_input])


if __name__ == "__main__":
    # End-to-end flow: fetch the quantized exported program, lower it to
    # an ExecuTorch .pte, then smoke-test it with a single inference.
    model_path = download_file(conformer_model_url, current_dir)
    generate_pte(exported_program_path=model_path, output_path=conformer_pte_path)
    run_inference(conformer_pte_path)
