#!/usr/bin/env python3
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.

import argparse
from datetime import timedelta
import numpy as np
import model_runtime
import popef
"""
The example shows loading a model from PopEF files, creating 2 model replicas
and sending inference requests to each of them.
"""


def main():
    """Run one synchronous inference request on each replica of a PopEF model.

    Parses the ``-p/--popef`` command-line option (one or more PopEF file
    paths), creates a ``model_runtime.ModelRunner`` configured with two
    replicas, fills every model input with normally distributed random data
    and, for each replica, executes the model once and prints every output
    tensor.

    Returns:
        int: 0 after all replicas have been executed.
    """
    parser = argparse.ArgumentParser("Model runner simple example.")
    parser.add_argument(
        "-p",
        "--popef",
        type=str,
        metavar='popef_file_path',
        help="A collection of PopEF files containing the model.",
        nargs='+',
        required=True)
    args = parser.parse_args()

    num_replicas = 2
    # Create model runner
    config = model_runtime.ModelRunnerConfig()
    config.replication_factor = num_replicas
    # Wait for a device with a 600 s timeout and a 1 s sleepTime
    # (presumably the interval between acquisition attempts - confirm
    # against the model_runtime documentation).
    config.device_wait_config = model_runtime.DeviceWaitConfig(
        model_runtime.DeviceWaitStrategy.WAIT_WITH_TIMEOUT,
        timeout=timedelta(seconds=600),
        sleepTime=timedelta(seconds=1))

    print("Creating ModelRunner with", config)
    runner = model_runtime.ModelRunner(model_runtime.PopefPaths(args.popef),
                                       config=config)

    print("Preparing input tensors:")
    input_descriptions = runner.getExecuteInputs()
    # One random array per model input, matching the described shape and
    # dtype. The list stays referenced until the function returns;
    # presumably InputMemoryView only references the array memory rather
    # than copying it - confirm before shortening its lifetime.
    input_tensors = [
        np.random.randn(*input_desc.shape).astype(input_desc.numpy_data_type())
        for input_desc in input_descriptions
    ]
    input_view = model_runtime.InputMemoryView()

    for input_desc, input_tensor in zip(input_descriptions, input_tensors):
        print("\tname:", input_desc.name, "shape:", input_tensor.shape,
              "dtype:", input_tensor.dtype)
        input_view[input_desc.name] = input_tensor

    # The output descriptions are queried once and reused for every replica.
    output_descriptions = runner.getExecuteOutputs()

    for replica_id in range(num_replicas):
        # NOTE(review): the message below says "empty data", but the request
        # carries the random tensors prepared above.
        print("Sending single synchronous request with empty data - replica",
              replica_id, ".")
        result = runner.execute(input_view, replica_id=replica_id)

        print("Processing output tensors - replica", replica_id, ":")
        for output_desc in output_descriptions:
            # Reinterpret the returned buffer as an array of the described
            # dtype and shape.
            output_tensor = np.frombuffer(
                result[output_desc.name],
                dtype=output_desc.numpy_data_type()).reshape(output_desc.shape)
            print("\tname:", output_desc.name, "shape:", output_tensor.shape,
                  "dtype:", output_tensor.dtype, "\n", output_tensor)

    print("Success: exiting")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status instead of
    # discarding it.
    raise SystemExit(main())
