#!/usr/bin/env python3
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.

import argparse
from datetime import timedelta
import numpy as np
import model_runtime
import popef
"""
This example loads a model from one or more PopEF files and sends a single
inference request to it.
"""


def main():
    """Load a model from PopEF files and run a single inference request.

    Parses ``-p/--popef`` from the command line, builds one model from all of
    the given files, acquires a device for it (waiting up to 600 s), binds a
    session to that device, enqueues randomly initialised input and output
    buffers, and runs the load programs followed by the main programs once.

    Returns:
        0 on success.
    """
    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name shown in the usage line), so the text must go in ``description``.
    parser = argparse.ArgumentParser(description="Model runner simple example.")
    parser.add_argument(
        "-p",
        "--popef",
        type=str,
        metavar='popef_file_path',
        help="A collection of PopEF files containing the model.",
        nargs='+',
        required=True)
    args = parser.parse_args()

    # One reader has to parse every file. ModelBuilder only sees what this
    # reader has parsed, and the model may be spread across several files.
    reader = popef.Reader()
    for model_file in args.popef:
        reader.parseFile(model_file)

    model = popef.ModelBuilder(reader).createModel()

    dm = model_runtime.DeviceManager()
    # Poll for a free device once per second, for at most 10 minutes.
    wait_config = model_runtime.DeviceWaitConfig(
        model_runtime.DeviceWaitStrategy.WAIT_WITH_TIMEOUT,
        timeout=timedelta(seconds=600),
        sleepTime=timedelta(seconds=1))
    device = dm.getDevice(model, wait_config=wait_config)
    print("Device acquired:", device)

    session = model_runtime.Session(model)
    print("Created Session.")

    session.bindToDevice(device)

    queue_manager = session.createQueueManager()
    print("Created QueueManager.")

    user_input_anchors = session.getUserInputAnchors()
    input_data = prepare_io_memory(user_input_anchors)

    for anchor, tensor in zip(user_input_anchors, input_data):
        anchor_name = anchor.name()
        print("Enqueue input anchor ", anchor_name)
        queue_manager.inputs[anchor_name].enqueue(tensor)

    # Output buffers must be sized and typed from the *output* anchors, and
    # they must stay in the same order so zip() pairs each buffer with its
    # own anchor.
    user_output_anchors = session.getUserOutputAnchors()
    output_data = prepare_io_memory(user_output_anchors)

    for anchor, tensor in zip(user_output_anchors, output_data):
        anchor_name = anchor.name()
        print("Enqueue output anchor ", anchor_name)
        queue_manager.outputs[anchor_name].enqueue(tensor)

    print("Running load programs")
    session.runLoadPrograms()

    print("Running main programs")
    session.runMainPrograms()

    print("Execution finished. Results available.")

    print("Success: exiting")
    return 0


def prepare_io_memory(anchors):
    """Allocate one randomly initialised numpy buffer per anchor.

    For each anchor, the buffer shape comes from ``anchor.tensorInfo().shape()``
    and its dtype from ``anchor.tensorInfo().numpyDType()``. Values are drawn
    from a standard normal distribution using the global ``np.random`` state,
    then cast to that dtype.

    Args:
        anchors: iterable of anchor objects exposing ``tensorInfo()``.

    Returns:
        list of ``np.ndarray``, one per anchor, in the same order as
        ``anchors``. A scalar anchor (empty shape) yields a 0-d array.
    """
    data = []
    for anchor in anchors:
        info = anchor.tensorInfo()
        # np.random.randn() with no arguments returns a plain Python float,
        # which has no .astype(). np.asarray turns it into a 0-d array so
        # scalar anchors are handled like every other anchor.
        samples = np.asarray(np.random.randn(*info.shape()))
        data.append(samples.astype(info.numpyDType()))
    return data


if __name__ == "__main__":
    # Use main()'s return value as the process exit status instead of
    # discarding it.
    raise SystemExit(main())
