#!/usr/bin/env python3
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.

"""
This example shows how to load a model from PopEF files and send inference
requests to the same model from multiple threads.
"""

import argparse
import sys
import threading
from datetime import timedelta

import numpy as np

import model_runtime


def main():
    parser = argparse.ArgumentParser(
        description="Model runner multithreading example.")
    parser.add_argument(
        "-p",
        "--popef",
        type=str,
        metavar='popef_file_path',
        help="A collection of PopEF files containing the model.",
        nargs='+',
        required=True)
    args = parser.parse_args()

    # thread_safe=True is required because several worker threads share a
    # single ModelRunner instance below.
    config = model_runtime.ModelRunnerConfig()
    config.thread_safe = True
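    # If no IPU is free, keep retrying attachment once per second for up to
    # 10 minutes rather than failing immediately.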
    config.device_wait_config = model_runtime.DeviceWaitConfig(
        model_runtime.DeviceWaitStrategy.WAIT_WITH_TIMEOUT,
        timeout=timedelta(seconds=600),
        sleepTime=timedelta(seconds=1))

    print("Creating ModelRunner with", config)
    model_runner = model_runtime.ModelRunner(
        model_runtime.PopefPaths(args.popef), config=config)
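    # All workers submit requests to this one shared, thread-safe runner.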
    num_workers = 4
    print("Starting", num_workers, "worker threads.")
    threads = [
        threading.Thread(target=workerMain, args=(model_runner, worker_id))
        for worker_id in range(num_workers)
    ]

    for thread in threads:
        thread.start()

    for thread in threads:
        thread.join()

    print("Success: exiting")
    return 0


def workerMain(model_runner, worker_id):
    print("Worker", worker_id, "Starting workerMain()")
    num_requests = 5

    # Query the descriptions (name, shape, dtype) of the model's input tensors.
    input_descriptions = model_runner.getExecuteInputs()
    input_requests = []

    print("Worker", worker_id, "Allocating input tensors for", num_requests,
          "requests", input_descriptions)
    for _ in range(num_requests):
        input_requests.append([
            np.random.randn(*input_desc.shape).astype(
                input_desc.numpy_data_type())
            for input_desc in input_descriptions
        ])

    futures = []

    # Submit every request without blocking; executeAsync() returns a future
    # that completes once the corresponding inference has run on the device.
    for req_id in range(num_requests):
        print("Worker", worker_id, "Sending asynchronous request. Request id",
              req_id)
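        # Pack this request's inputs into an InputMemoryView, keyed by
        # tensor name.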
        input_view = model_runtime.InputMemoryView()
        for input_desc, input_tensor in zip(input_descriptions,
                                            input_requests[req_id]):
            input_view[input_desc.name] = input_tensor
        futures.append(model_runner.executeAsync(input_view))

    print("Worker", worker_id, "Processing outputs.")
    for req_id, future in enumerate(futures):
        print("Worker", worker_id, "Waiting for the result - request", req_id)
        future.wait()
        print("Worker", worker_id, "Result available - request", req_id)


if __name__ == "__main__":
    sys.exit(main())
