#!/usr/bin/env python3
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.

import os
import argparse
from datetime import timedelta
import numpy as np
import model_runtime
import popef
"""
The example shows loading a model from PopEF file and binding constant tensor
value to one of the inputs. The example is based on the PopEF file generated
by `model_runtime_example_generate_simple_popef` example. Generated PopEF
file consists simple model:

output = (A * weights) + B

where A and B are stream inputs, weights is a tensor saved as popef::TensorData
and  output is result stream output tensor.
"""


def _parse_args():
    """Parse the command line; the result has a ``popef`` list of paths."""
    # `description=` (not the first positional argument, which is `prog`)
    # so the text shows up as help text rather than as the program name.
    parser = argparse.ArgumentParser(
        description="Model runner simple example.")
    parser.add_argument(
        "-p",
        "--popef",
        type=str,
        metavar='popef_file_path',
        help="A collection of PopEF files containing the model.",
        nargs='+',
        required=True)
    return parser.parse_args()


def _find_anchor(model, name):
    """Return the anchor called *name* from ``model.metadata.anchors()``.

    Raises:
        RuntimeError: if the model has no anchor with that name.
    """
    for anchor in model.metadata.anchors():
        if anchor.name() == name:
            return anchor
    raise RuntimeError(f'Anchor {name} not found inside given '
                       'model. Please make sure that PopEF was generated by '
                       '`model_runtime_example_generate_simple_popef`')


def _random_tensor(shape, dtype):
    """Return standard-normal samples with the given *shape*, cast to *dtype*.

    ``np.random.randn()`` with no dimensions returns a plain Python ``float``,
    which has no ``astype``; wrapping in ``np.asarray`` makes scalar (empty
    shape) tensors work while leaving non-scalar results unchanged.
    """
    return np.asarray(np.random.randn(*shape)).astype(dtype)


def main():
    """Run one inference request on a PopEF model with a frozen input.

    Loads the model from the files given via ``--popef``, binds random values
    to the ``tensor_B`` input as a frozen (constant) input, fills every
    remaining execute input with random data, sends a single synchronous
    request and prints each output tensor.

    Returns:
        0 on success.

    Raises:
        RuntimeError: if the model has no anchor named ``tensor_B``.
    """
    args = _parse_args()
    model = load_model(args.popef)

    frozen_input_name = "tensor_B"
    print("Looking for tensor", frozen_input_name, "inside PopEF model.")
    tensor_b_anchor = _find_anchor(model, frozen_input_name)

    print("Generating", frozen_input_name, "random values")
    tensor_b_info = tensor_b_anchor.tensorInfo()
    tensor_b = _random_tensor(tensor_b_info.shape(),
                              tensor_b_info.numpyDType())

    config = model_runtime.ModelRunnerConfig()

    frozen_inputs = model_runtime.InputMemoryView()
    frozen_inputs[frozen_input_name] = tensor_b
    config.frozen_inputs = frozen_inputs

    print(
        "Tensor", frozen_input_name, "is frozen - will be treated as "
        "constant in each execution request.")
    # NOTE(review): presumably waits up to 600 s, re-trying every 1 s, to
    # acquire a device - confirm against model_runtime.DeviceWaitConfig docs.
    config.device_wait_config = model_runtime.DeviceWaitConfig(
        model_runtime.DeviceWaitStrategy.WAIT_WITH_TIMEOUT,
        timeout=timedelta(seconds=600),
        sleepTime=timedelta(seconds=1))

    model_runner = model_runtime.ModelRunner(model, config=config)

    print("Preparing input tensors:")
    input_descriptions = model_runner.getExecuteInputs()
    # Keep the arrays referenced in a list until execution finishes; an
    # InputMemoryView may only view the buffers (assumption - not verified).
    input_tensors = [
        _random_tensor(input_desc.shape, input_desc.numpy_data_type())
        for input_desc in input_descriptions
    ]
    input_view = model_runtime.InputMemoryView()

    for input_desc, input_tensor in zip(input_descriptions, input_tensors):
        print("\tname:", input_desc.name, "shape:", input_tensor.shape,
              "dtype:", input_tensor.dtype)
        input_view[input_desc.name] = input_tensor

    print("Sending single synchronous request with random data.")
    result = model_runner.execute(input_view)
    output_descriptions = model_runner.getExecuteOutputs()

    print("Processing output tensors:")
    for output_desc in output_descriptions:
        # Reinterpret the raw result buffer using the declared dtype/shape.
        output_tensor = np.frombuffer(
            result[output_desc.name],
            dtype=output_desc.numpy_data_type()).reshape(output_desc.shape)
        print("\tname:", output_desc.name, "shape:", output_tensor.shape,
              "dtype:", output_tensor.dtype, "\n", output_tensor)

    print("Success: exiting")

    return 0


def load_model(popef_paths):
    """Build a model from one or more PopEF files.

    Every file is parsed into a single ``popef.Reader`` before the model is
    built, so all given files contribute to the model.

    Args:
        popef_paths: Iterable of paths to PopEF files. A single path
            (``str`` or ``os.PathLike``) is also accepted.

    Returns:
        The model created by ``popef.ModelBuilder(reader).createModel()``.

    Raises:
        ValueError: if no paths are given.
        FileNotFoundError: if any path is not an existing regular file.
    """
    # A bare string is iterable character by character; treat it as one path.
    if isinstance(popef_paths, (str, os.PathLike)):
        popef_paths = [popef_paths]
    paths = [os.fspath(path) for path in popef_paths]
    if not paths:
        raise ValueError("At least one PopEF file path is required.")

    # Validate all paths up front with a real exception: `assert` is
    # stripped when Python runs with -O.
    for model_file in paths:
        if not os.path.isfile(model_file):
            raise FileNotFoundError(f"PopEF file not found: {model_file}")

    reader = popef.Reader()
    for model_file in paths:
        reader.parseFile(model_file)

    return popef.ModelBuilder(reader).createModel()


if __name__ == "__main__":
    main()
