#!/usr/bin/env python3

# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import base64
import gzip
import zlib
from urllib.parse import quote

import gevent
import gevent.pool
import rapidjson as json
from geventhttpclient import HTTPClient
from geventhttpclient.url import URL
from tritonclient.utils import raise_error

from .._client import InferenceServerClientBase
from .._request import Request
from ._infer_result import InferResult
from ._utils import _get_inference_request, _get_query_string, _raise_if_error


class InferAsyncRequest:
    """An object of InferAsyncRequest class is used to describe
    a handle to an ongoing asynchronous inference request.

    Parameters
    ----------
    greenlet : gevent.Greenlet
        The greenlet object which will provide the results.
        For further details about greenlets refer
        http://www.gevent.org/api/gevent.greenlet.html.

    verbose : bool
        If True generate verbose output. Default value is False.
    """

    def __init__(self, greenlet, verbose=False):
        self._greenlet = greenlet
        self._verbose = verbose

    def get_result(self, block=True, timeout=None):
        """Get the results of the associated asynchronous inference.
        Parameters
        ----------
        block : bool
            If block is True, the function will wait till the
            corresponding response is received from the server.
            Default value is True.
        timeout : int
            The maximum wait time for the function. This setting is
            ignored if the block is set False. Default is None,
            which means the function will block indefinitely till
            the corresponding response is received.

        Returns
        -------
        InferResult
            The object holding the result of the async inference.

        Raises
        ------
        InferenceServerException
            If server fails to perform inference or failed to respond
            within specified timeout.
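
        Examples
        --------
        A minimal sketch, assuming 'client' is an InferenceServerClient and
        'inputs' is a prepared list of InferInput objects ('my_model' is a
        placeholder model name):

        >>> handle = client.async_infer("my_model", inputs)
        >>> result = handle.get_result(block=True, timeout=10)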
        """

        try:
            response = self._greenlet.get(block=block, timeout=timeout)
        except gevent.Timeout:
            raise_error("failed to obtain inference response")

        _raise_if_error(response)
        return InferResult(response, self._verbose)


class InferenceServerClient(InferenceServerClientBase):
    """An InferenceServerClient object is used to perform any kind of
    communication with the InferenceServer using http protocol. None
    of the methods are thread safe. The object is intended to be used
    by a single thread and simultaneously calling different methods
    with different threads is not supported and will cause undefined
    behavior.

    Parameters
    ----------
    url : str
        The inference server name, port and optional base path
        in the following format: host:port/<base-path>, e.g.
        'localhost:8000'.

    verbose : bool
        If True generate verbose output. Default value is False.
    concurrency : int
        The number of connections to create for this client.
        Default value is 1.
    connection_timeout : float
        The timeout value for the connection. Default value
        is 60.0 sec.
    network_timeout : float
        The timeout value for the network. Default value is
        60.0 sec.
    max_greenlets : int
        Determines the maximum allowed number of worker greenlets
        for handling asynchronous inference requests. Default value
        is None, which means there will be no restriction on the
        number of greenlets created.
    ssl : bool
        If True, channels the requests over the encrypted HTTPS scheme.
        Improper settings may cause the connection to terminate
        prematurely with an unsuccessful handshake. See the
        `ssl_context_factory` option for using secure default
        settings. Default value for this option is False.
    ssl_options : dict
        Any options supported by `ssl.wrap_socket` specified as a
        dictionary. The argument is ignored if 'ssl' is specified
        False.
    ssl_context_factory : SSLContext callable
        It must be a callable that returns an SSLContext. Set to
        `gevent.ssl.create_default_context` to use contexts with
        secure default settings. This should most likely resolve
        connection issues in a secure way. The default value for
        this option is None, which directly wraps the socket with
        the options provided via `ssl_options`. The argument is
        ignored if 'ssl' is specified False.
    insecure : bool
        If True, the host name is not verified against the certificate.
        Default value is False. The argument is ignored if 'ssl' is
        specified False.

    Raises
    ------
    Exception
        If unable to create a client.

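    Examples
    --------
    A minimal construction sketch (assumes a Triton server reachable at
    localhost:8000; adjust the URL for your deployment):

    >>> client = InferenceServerClient(url="localhost:8000")
    >>> client.is_server_live()
    True
    >>> client.close()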
    """

    def __init__(
        self,
        url,
        verbose=False,
        concurrency=1,
        connection_timeout=60.0,
        network_timeout=60.0,
        max_greenlets=None,
        ssl=False,
        ssl_options=None,
        ssl_context_factory=None,
        insecure=False,
    ):
        super().__init__()
        if url.startswith("http://") or url.startswith("https://"):
            raise_error("url should not include the scheme")
        scheme = "https://" if ssl else "http://"
        self._parsed_url = URL(scheme + url)
        self._base_uri = self._parsed_url.request_uri.rstrip("/")
        self._client_stub = HTTPClient.from_url(
            self._parsed_url,
            concurrency=concurrency,
            connection_timeout=connection_timeout,
            network_timeout=network_timeout,
            ssl_options=ssl_options,
            ssl_context_factory=ssl_context_factory,
            insecure=insecure,
        )
        self._pool = gevent.pool.Pool(max_greenlets)
        self._verbose = verbose

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def close(self):
        """Close the client. Any future calls to server
        will result in an Error.

        """
        self._pool.join()
        self._client_stub.close()

    def _get(self, request_uri, headers, query_params):
        """Issues the GET request to the server

        Parameters
        ----------
        request_uri: str
            The request URI to be used in GET request.
        headers: dict
            Additional HTTP headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction.

        Returns
        -------
        geventhttpclient.response.HTTPSocketPoolResponse
            The response from server.
        """
        request = Request(headers)
        self._call_plugin(request)

        # Update the headers based on plugin invocation
        headers = request.headers
        self._validate_headers(headers)

        if self._base_uri is not None:
            request_uri = self._base_uri + "/" + request_uri

        if query_params is not None:
            request_uri = request_uri + "?" + _get_query_string(query_params)

        if self._verbose:
            print("GET {}, headers {}".format(request_uri, headers))

        if headers is not None:
            response = self._client_stub.get(request_uri, headers=headers)
        else:
            response = self._client_stub.get(request_uri)

        if self._verbose:
            print(response)

        return response

    def _post(self, request_uri, request_body, headers, query_params):
        """Issues the POST request to the server

        Parameters
        ----------
        request_uri: str
            The request URI to be used in POST request.
        request_body: str
            The body of the request
        headers: dict
            Additional HTTP headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction.

        Returns
        -------
        geventhttpclient.response.HTTPSocketPoolResponse
            The response from server.
        """
        request = Request(headers)
        self._call_plugin(request)

        # Update the headers based on plugin invocation
        headers = request.headers
        self._validate_headers(headers)

        if self._base_uri is not None:
            request_uri = self._base_uri + "/" + request_uri

        if query_params is not None:
            request_uri = request_uri + "?" + _get_query_string(query_params)

        if self._verbose:
            print("POST {}, headers {}\n{}".format(request_uri, headers, request_body))

        if headers is not None:
            response = self._client_stub.post(
                request_uri=request_uri, body=request_body, headers=headers
            )
        else:
            response = self._client_stub.post(
                request_uri=request_uri, body=request_body
            )

        if self._verbose:
            print(response)

        return response

    def _validate_headers(self, headers):
        """Checks for any unsupported HTTP headers before processing a request.

        Parameters
        ----------
        headers: dict
            HTTP headers to validate before processing the request.

        Raises
        ------
        InferenceServerException
            If an unsupported HTTP header is included in a request.
        """
        if not headers:
            return

        # HTTP headers are case-insensitive, so force lowercase for comparison
        headers_lowercase = {k.lower(): v for k, v in headers.items()}
        # The python client library (and geventhttpclient) do not encode request
        # data based on "Transfer-Encoding" header, so reject this header if
        # included. Other libraries may do this encoding under the hood.
        # The python client library does expose special arguments to support
        # some "Content-Encoding" headers.
        if "transfer-encoding" in headers_lowercase:
            raise_error(
                "Unsupported HTTP header: 'Transfer-Encoding' is not "
                "supported in the Python client library. Use raw HTTP "
                "request libraries or the C++ client instead for this "
                "header."
            )

    def is_server_live(self, headers=None, query_params=None):
        """Contact the inference server and get liveness.

        Parameters
        ----------
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction.

        Returns
        -------
        bool
            True if server is live, False if server is not live.

        Raises
        ------
        Exception
            If unable to get liveness.

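        Examples
        --------
        A minimal health-check sketch, assuming the client was created
        against a running server:

        >>> if client.is_server_live() and client.is_server_ready():
        ...     print("server is up and ready")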
        """

        request_uri = "v2/health/live"
        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )

        return response.status_code == 200

    def is_server_ready(self, headers=None, query_params=None):
        """Contact the inference server and get readiness.

        Parameters
        ----------
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction.

        Returns
        -------
        bool
            True if server is ready, False if server is not ready.

        Raises
        ------
        Exception
            If unable to get readiness.

        """
        request_uri = "v2/health/ready"
        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )

        return response.status_code == 200

    def is_model_ready(
        self, model_name, model_version="", headers=None, query_params=None
    ):
        """Contact the inference server and get the readiness of specified model.

        Parameters
        ----------
        model_name: str
            The name of the model to check for readiness.
        model_version: str
            The version of the model to check for readiness. The default value
            is an empty string, which means the server will choose a version
            based on the model and internal policy.
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction.

        Returns
        -------
        bool
            True if the model is ready, False if not ready.

        Raises
        ------
        Exception
            If unable to get model readiness.

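        Examples
        --------
        A minimal sketch ('my_model' is a placeholder model name):

        >>> client.is_model_ready("my_model")
        True
        >>> client.is_model_ready("my_model", model_version="1")
        True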
        """
        if not isinstance(model_version, str):
            raise_error("model version must be a string")
        if model_version != "":
            request_uri = "v2/models/{}/versions/{}/ready".format(
                quote(model_name), model_version
            )
        else:
            request_uri = "v2/models/{}/ready".format(quote(model_name))

        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )

        return response.status_code == 200

    def get_server_metadata(self, headers=None, query_params=None):
        """Contact the inference server and get its metadata.

        Parameters
        ----------
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction.

        Returns
        -------
        dict
            The JSON dict holding the metadata.

        Raises
        ------
        InferenceServerException
            If unable to get server metadata.

        """
        request_uri = "v2"
        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def get_model_metadata(
        self, model_name, model_version="", headers=None, query_params=None
    ):
        """Contact the inference server and get the metadata for specified model.

        Parameters
        ----------
        model_name: str
            The name of the model
        model_version: str
            The version of the model to get metadata. The default value
            is an empty string, which means the server will choose
            a version based on the model and internal policy.
        headers: dict
            Optional dictionary specifying additional
            HTTP headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Returns
        -------
        dict
            The JSON dict holding the metadata.

        Raises
        ------
        InferenceServerException
            If unable to get model metadata.

        """
        if not isinstance(model_version, str):
            raise_error("model version must be a string")
        if model_version != "":
            request_uri = "v2/models/{}/versions/{}".format(
                quote(model_name), model_version
            )
        else:
            request_uri = "v2/models/{}".format(quote(model_name))

        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def get_model_config(
        self, model_name, model_version="", headers=None, query_params=None
    ):
        """Contact the inference server and get the configuration for specified model.

        Parameters
        ----------
        model_name: str
            The name of the model
        model_version: str
            The version of the model to get configuration. The default value
            is an empty string, which means the server will choose
            a version based on the model and internal policy.
        headers: dict
            Optional dictionary specifying additional
            HTTP headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Returns
        -------
        dict
            The JSON dict holding the model config.

        Raises
        ------
        InferenceServerException
            If unable to get model configuration.

        """
        if model_version != "":
            request_uri = "v2/models/{}/versions/{}/config".format(
                quote(model_name), model_version
            )
        else:
            request_uri = "v2/models/{}/config".format(quote(model_name))

        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def get_model_repository_index(self, headers=None, query_params=None):
        """Get the index of model repository contents

        Parameters
        ----------
        headers: dict
            Optional dictionary specifying additional
            HTTP headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Returns
        -------
        dict
            The JSON dict holding the model repository index.

        Raises
        ------
        InferenceServerException
            If unable to get the repository index.

        """
        request_uri = "v2/repository/index"
        response = self._post(
            request_uri=request_uri,
            request_body="",
            headers=headers,
            query_params=query_params,
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def load_model(
        self, model_name, headers=None, query_params=None, config=None, files=None
    ):
        """Request the inference server to load or reload specified model.

        Parameters
        ----------
        model_name : str
            The name of the model to be loaded.
        headers: dict
            Optional dictionary specifying additional
            HTTP headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction.
        config: str
            Optional JSON representation of a model config provided for
            the load request. If provided, this config will be used for
            loading the model.
        files: dict
            Optional dictionary specifying file path (with "file:" prefix) in
            the override model directory to the file content as bytes.
            The files will form the model directory that the model will be
            loaded from. If specified, 'config' must be provided to be
            the model configuration of the override model directory.

        Raises
        ------
        InferenceServerException
            If unable to load the model.

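        Examples
        --------
        A hedged sketch; 'my_model' is a placeholder and the config string
        is a minimal illustration, not a complete model configuration:

        >>> client.load_model("my_model")
        >>> client.load_model("my_model", config='{"max_batch_size": 8}')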
        """
        request_uri = "v2/repository/models/{}/load".format(quote(model_name))
        load_request = {}
        if config is not None:
            if "parameters" not in load_request:
                load_request["parameters"] = {}
            load_request["parameters"]["config"] = config
        if files is not None:
            for path, content in files.items():
                if "parameters" not in load_request:
                    load_request["parameters"] = {}
                # b64encode returns bytes; decode so the value is JSON-serializable
                load_request["parameters"][path] = base64.b64encode(content).decode(
                    "utf-8"
                )
        response = self._post(
            request_uri=request_uri,
            request_body=json.dumps(load_request),
            headers=headers,
            query_params=query_params,
        )
        _raise_if_error(response)
        if self._verbose:
            print("Loaded model '{}'".format(model_name))

    def unload_model(
        self, model_name, headers=None, query_params=None, unload_dependents=False
    ):
        """Request the inference server to unload specified model.

        Parameters
        ----------
        model_name : str
            The name of the model to be unloaded.
        headers: dict
            Optional dictionary specifying additional
            HTTP headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction
        unload_dependents : bool
            Whether the dependents of the model should also be unloaded.

        Raises
        ------
        InferenceServerException
            If unable to unload the model.

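        Examples
        --------
        A minimal sketch ('my_model' is a placeholder model name):

        >>> client.unload_model("my_model", unload_dependents=True)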
        """
        request_uri = "v2/repository/models/{}/unload".format(quote(model_name))
        unload_request = {"parameters": {"unload_dependents": unload_dependents}}
        response = self._post(
            request_uri=request_uri,
            request_body=json.dumps(unload_request),
            headers=headers,
            query_params=query_params,
        )
        _raise_if_error(response)
        if self._verbose:
            print("Unloaded model '{}'".format(model_name))

    def get_inference_statistics(
        self, model_name="", model_version="", headers=None, query_params=None
    ):
        """Get the inference statistics for the specified model name and
        version.

        Parameters
        ----------
        model_name : str
            The name of the model to get statistics. The default value is
            an empty string, which means statistics of all models will
            be returned.
        model_version: str
            The version of the model to get inference statistics. The
            default value is an empty string, which means the server
            will return the statistics of all available model versions.
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Returns
        -------
        dict
            The JSON dict holding the model inference statistics.

        Raises
        ------
        InferenceServerException
            If unable to get the model inference statistics.

        """

        if model_name != "":
            if not isinstance(model_version, str):
                raise_error("model version must be a string")
            if model_version != "":
                request_uri = "v2/models/{}/versions/{}/stats".format(
                    quote(model_name), model_version
                )
            else:
                request_uri = "v2/models/{}/stats".format(quote(model_name))
        else:
            request_uri = "v2/models/stats"

        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def update_trace_settings(
        self, model_name=None, settings=None, headers=None, query_params=None
    ):
        """Update the trace settings for the specified model name, or
        global trace settings if model name is not given.
        Returns the trace settings after the update.

        Parameters
        ----------
        model_name : str
            The name of the model to update trace settings. Specifying None or
            empty string will update the global trace settings.
            The default value is None.
        settings: dict
            The new trace setting values. Only the settings listed will be
            updated. If a trace setting is listed in the dictionary with
            a value of 'None', that setting will be cleared. The default
            value is None, which is treated as an empty dict.
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Returns
        -------
        dict
            The JSON dict holding the updated trace settings.

        Raises
        ------
        InferenceServerException
            If unable to update the trace settings.

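        Examples
        --------
        A hedged sketch; "trace_level" and "trace_rate" are examples of
        settings understood by the server's trace API:

        >>> client.update_trace_settings(
        ...     model_name="my_model",
        ...     settings={"trace_level": ["TIMESTAMPS"], "trace_rate": "1"},
        ... )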
        """

        settings = {} if settings is None else settings

        if (model_name is not None) and (model_name != ""):
            request_uri = "v2/models/{}/trace/setting".format(quote(model_name))
        else:
            request_uri = "v2/trace/setting"

        response = self._post(
            request_uri=request_uri,
            request_body=json.dumps(settings),
            headers=headers,
            query_params=query_params,
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def get_trace_settings(self, model_name=None, headers=None, query_params=None):
        """Get the trace settings for the specified model name, or global trace
        settings if model name is not given

        Parameters
        ----------
        model_name : str
            The name of the model to get trace settings. Specifying None or
            empty string will return the global trace settings.
            The default value is None.
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Returns
        -------
        dict
            The JSON dict holding the trace settings.

        Raises
        ------
        InferenceServerException
            If unable to get the trace settings.

        """

        if (model_name is not None) and (model_name != ""):
            request_uri = "v2/models/{}/trace/setting".format(quote(model_name))
        else:
            request_uri = "v2/trace/setting"

        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def update_log_settings(self, settings, headers=None, query_params=None):
        """Update the global log settings of the Triton server.
        Parameters
        ----------
        settings: dict
            The new log setting values. Only the settings listed will be
            updated.
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction
        Returns
        -------
        dict
            The JSON dict holding the updated log settings.
        Raises
        ------
        InferenceServerException
            If unable to update the log settings.
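
        Examples
        --------
        A hedged sketch; "log_verbose_level" and "log_info" are examples of
        settings understood by the server's logging API:

        >>> client.update_log_settings({"log_verbose_level": 1, "log_info": True})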
        """
        request_uri = "v2/logging"
        response = self._post(
            request_uri=request_uri,
            request_body=json.dumps(settings),
            headers=headers,
            query_params=query_params,
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def get_log_settings(self, headers=None, query_params=None):
        """Get the global log settings for the Triton server
        Parameters
        ----------
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction
        Returns
        -------
        dict
            The JSON dict holding the log settings.
        Raises
        ------
        InferenceServerException
            If unable to get the log settings.
        """

        request_uri = "v2/logging"

        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def get_system_shared_memory_status(
        self, region_name="", headers=None, query_params=None
    ):
        """Request system shared memory status from the server.

        Parameters
        ----------
        region_name : str
            The name of the region to query status. The default
            value is an empty string, which means that the status
            of all active system shared memory will be returned.
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Returns
        -------
        dict
            The JSON dict holding system shared memory status.

        Raises
        ------
        InferenceServerException
            If unable to get the status of specified shared memory.

        """
        if region_name != "":
            request_uri = "v2/systemsharedmemory/region/{}/status".format(
                quote(region_name)
            )
        else:
            request_uri = "v2/systemsharedmemory/status"

        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def register_system_shared_memory(
        self, name, key, byte_size, offset=0, headers=None, query_params=None
    ):
        """Request the server to register a system shared memory with the
        following specification.

        Parameters
        ----------
        name : str
            The name of the region to register.
        key : str
            The key of the underlying memory object that contains the
            system shared memory region.
        byte_size : int
            The size of the system shared memory region, in bytes.
        offset : int
            Offset, in bytes, within the underlying memory object to
            the start of the system shared memory region. The default
            value is zero.
        headers: dict
            Optional dictionary specifying additional
            HTTP headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Raises
        ------
        InferenceServerException
            If unable to register the specified system shared memory.

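        Examples
        --------
        A hedged sketch; it assumes a system shared-memory object has
        already been created under the key '/my_shm' (for example via the
        tritonclient.utils.shared_memory utilities):

        >>> client.register_system_shared_memory("my_region", "/my_shm", 1024)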
        """
        request_uri = "v2/systemsharedmemory/region/{}/register".format(quote(name))

        register_request = {"key": key, "offset": offset, "byte_size": byte_size}
        request_body = json.dumps(register_request)

        response = self._post(
            request_uri=request_uri,
            request_body=request_body,
            headers=headers,
            query_params=query_params,
        )
        _raise_if_error(response)
        if self._verbose:
            print("Registered system shared memory with name '{}'".format(name))

    def unregister_system_shared_memory(self, name="", headers=None, query_params=None):
        """Request the server to unregister a system shared memory with the
        specified name.

        Parameters
        ----------
        name : str
            The name of the region to unregister. The default value is empty
            string which means all the system shared memory regions will be
            unregistered.
        headers: dict
            Optional dictionary specifying additional
            HTTP headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Raises
        ------
        InferenceServerException
            If unable to unregister the specified system shared memory region.

        """
        if name != "":
            request_uri = "v2/systemsharedmemory/region/{}/unregister".format(
                quote(name)
            )
        else:
            request_uri = "v2/systemsharedmemory/unregister"

        response = self._post(
            request_uri=request_uri,
            request_body="",
            headers=headers,
            query_params=query_params,
        )
        _raise_if_error(response)
        if self._verbose:
            if name != "":
                print("Unregistered system shared memory with name '{}'".format(name))
            else:
                print("Unregistered all system shared memory regions")

    def get_cuda_shared_memory_status(
        self, region_name="", headers=None, query_params=None
    ):
        """Request cuda shared memory status from the server.

        Parameters
        ----------
        region_name : str
            The name of the region to query status. The default
            value is an empty string, which means that the status
            of all active cuda shared memory will be returned.
        headers: dict
            Optional dictionary specifying additional
            HTTP headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Returns
        -------
        dict
            The JSON dict holding cuda shared memory status.

        Raises
        ------
        InferenceServerException
            If unable to get the status of specified shared memory.

        """
        if region_name != "":
            request_uri = "v2/cudasharedmemory/region/{}/status".format(
                quote(region_name)
            )
        else:
            request_uri = "v2/cudasharedmemory/status"

        response = self._get(
            request_uri=request_uri, headers=headers, query_params=query_params
        )
        _raise_if_error(response)

        content = response.read()
        if self._verbose:
            print(content)

        return json.loads(content)

    def register_cuda_shared_memory(
        self, name, raw_handle, device_id, byte_size, headers=None, query_params=None
    ):
        """Request the server to register a system shared memory with the
        following specification.

        Parameters
        ----------
        name : str
            The name of the region to register.
        raw_handle : bytes
            The raw serialized cudaIPC handle in base64 encoding.
        device_id : int
            The GPU device ID on which the cudaIPC handle was created.
        byte_size : int
            The size of the cuda shared memory region, in bytes.
        headers: dict
            Optional dictionary specifying additional
            HTTP headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Raises
        ------
        InferenceServerException
            If unable to register the specified cuda shared memory.

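        Examples
        --------
        A hedged sketch; it assumes 'raw_handle' holds the base64-serialized
        cudaIPC handle of an existing CUDA shared-memory region (for example
        obtained via the tritonclient.utils.cuda_shared_memory utilities):

        >>> client.register_cuda_shared_memory("my_region", raw_handle, 0, 1024)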
        """
        request_uri = "v2/cudasharedmemory/region/{}/register".format(quote(name))

        register_request = {
            "raw_handle": {"b64": raw_handle},
            "device_id": device_id,
            "byte_size": byte_size,
        }
        request_body = json.dumps(register_request)

        response = self._post(
            request_uri=request_uri,
            request_body=request_body,
            headers=headers,
            query_params=query_params,
        )
        _raise_if_error(response)
        if self._verbose:
            print("Registered cuda shared memory with name '{}'".format(name))

    def unregister_cuda_shared_memory(self, name="", headers=None, query_params=None):
        """Request the server to unregister a cuda shared memory with the
        specified name.

        Parameters
        ----------
        name : str
            The name of the region to unregister. The default value is empty
            string which means all the cuda shared memory regions will be
            unregistered.
        headers: dict
            Optional dictionary specifying additional
            HTTP headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction

        Raises
        ------
        InferenceServerException
            If unable to unregister the specified cuda shared memory region.

        """
        if name != "":
            request_uri = "v2/cudasharedmemory/region/{}/unregister".format(quote(name))
        else:
            request_uri = "v2/cudasharedmemory/unregister"

        response = self._post(
            request_uri=request_uri,
            request_body="",
            headers=headers,
            query_params=query_params,
        )
        _raise_if_error(response)
        if self._verbose:
            if name != "":
                print("Unregistered cuda shared memory with name '{}'".format(name))
            else:
                print("Unregistered all cuda shared memory regions")

    @staticmethod
    def generate_request_body(
        inputs,
        outputs=None,
        request_id="",
        sequence_id=0,
        sequence_start=False,
        sequence_end=False,
        priority=0,
        timeout=None,
        parameters=None,
    ):
        """Generate a request body for inference using the supplied 'inputs'
        requesting the outputs specified by 'outputs'.

        Parameters
        ----------
        inputs : list
            A list of InferInput objects, each describing data for an input
            tensor required by the model.
        outputs : list
            A list of InferRequestedOutput objects, each describing how the output
            data must be returned. If not specified all outputs produced
            by the model will be returned using default settings.
        request_id: str
            Optional identifier for the request. If specified will be returned
            in the response. Default value is an empty string which means no
            request_id will be used.
        sequence_id : int or str
            The unique identifier for the sequence being represented by the
            object. A value of 0 or "" means that the request does not
            belong to a sequence. Default is 0.
        sequence_start: bool
            Indicates whether the request being added marks the start of the
            sequence. Default value is False. This argument is ignored if
            'sequence_id' is 0.
        sequence_end: bool
            Indicates whether the request being added marks the end of the
            sequence. Default value is False. This argument is ignored if
            'sequence_id' is 0.
        priority : int
            Indicates the priority of the request. Priority value zero
            indicates that the default priority level should be used
            (i.e. same behavior as not specifying the priority parameter).
            Lower value priorities indicate higher priority levels. Thus
            the highest priority level is indicated by setting the parameter
            to 1, the next highest is 2, etc. If not provided, the server
            will handle the request using default setting for the model.
        timeout : int
            The timeout value for the request, in microseconds. If the request
            cannot be completed within the time the server can take a
            model-specific action such as terminating the request. If not
            provided, the server will handle the request using default setting
            for the model. This option is only respected by the model that is
            configured with dynamic batching. See here for more details:
            https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher
        parameters: dict
            Optional fields to be included in the 'parameters' fields.

        Returns
        -------
        bytes
            The request body of the inference.
        int
            The byte size of the inference request header in the request body.
            Returns None if the whole request body constitutes the request header.

        Raises
        ------
        InferenceServerException
            If server fails to perform inference.
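
        Examples
        --------
        A minimal offline sketch; 'inputs' is a prepared list of InferInput
        objects. The returned body can be sent over any HTTP stack and the
        response decoded later with parse_response_body():

        >>> body, header_length = InferenceServerClient.generate_request_body(
        ...     inputs
        ... )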
        """
        return _get_inference_request(
            inputs=inputs,
            request_id=request_id,
            outputs=outputs,
            sequence_id=sequence_id,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            priority=priority,
            timeout=timeout,
            custom_parameters=parameters,
        )

    @staticmethod
    def parse_response_body(
        response_body, verbose=False, header_length=None, content_encoding=None
    ):
        """Generate a InferResult object from the given 'response_body'

        Parameters
        ----------
        response_body : bytes
            The inference response from the server
        verbose : bool
            If True generate verbose output. Default value is False.
        header_length : int
            The length of the inference header if the header does not occupy
            the whole response body. Default value is None.
        content_encoding : string
            The encoding of the response body if it is compressed.
            Default value is None.

        Returns
        -------
        InferResult
            The InferResult object generated from the response body
        """
        return InferResult.from_response_body(
            response_body, verbose, header_length, content_encoding
        )

    def infer(
        self,
        model_name,
        inputs,
        model_version="",
        outputs=None,
        request_id="",
        sequence_id=0,
        sequence_start=False,
        sequence_end=False,
        priority=0,
        timeout=None,
        headers=None,
        query_params=None,
        request_compression_algorithm=None,
        response_compression_algorithm=None,
        parameters=None,
    ):
        """Run synchronous inference using the supplied 'inputs' requesting
        the outputs specified by 'outputs'.

        Parameters
        ----------
        model_name: str
            The name of the model to run inference.
        inputs : list
            A list of InferInput objects, each describing data for an input
            tensor required by the model.
        model_version: str
            The version of the model to run inference. The default value
            is an empty string, which means the server will choose
            a version based on the model and internal policy.
        outputs : list
            A list of InferRequestedOutput objects, each describing how the output
            data must be returned. If not specified all outputs produced
            by the model will be returned using default settings.
        request_id: str
            Optional identifier for the request. If specified will be returned
            in the response. Default value is an empty string which means no
            request_id will be used.
        sequence_id : int
            The unique identifier for the sequence being represented by the
            object. Default value is 0 which means that the request does not
            belong to a sequence.
        sequence_start: bool
            Indicates whether the request being added marks the start of the
            sequence. Default value is False. This argument is ignored if
            'sequence_id' is 0.
        sequence_end: bool
            Indicates whether the request being added marks the end of the
            sequence. Default value is False. This argument is ignored if
            'sequence_id' is 0.
        priority : int
            Indicates the priority of the request. Priority value zero
            indicates that the default priority level should be used
            (i.e. same behavior as not specifying the priority parameter).
            Lower value priorities indicate higher priority levels. Thus
            the highest priority level is indicated by setting the parameter
            to 1, the next highest is 2, etc. If not provided, the server
            will handle the request using default setting for the model.
        timeout : int
            The timeout value for the request, in microseconds. If the request
            cannot be completed within the time the server can take a
            model-specific action such as terminating the request. If not
            provided, the server will handle the request using default setting
            for the model. This option is only respected by the model that is
            configured with dynamic batching. See here for more details:
            https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request.
        query_params: dict
            Optional url query parameters to use in network
            transaction.
        request_compression_algorithm : str
            Optional HTTP compression algorithm to use for the request body on client side.
            Currently supports "deflate", "gzip" and None. By default, no
            compression is used.
        response_compression_algorithm : str
            Optional HTTP compression algorithm to request for the response body.
            Note that the response may not be compressed if the server does not
            support the specified algorithm. Currently supports "deflate",
            "gzip" and None. By default, no compression is requested.
        parameters: dict
            Optional fields to be included in the 'parameters' fields.

        Returns
        -------
        InferResult
            The object holding the result of the inference.

        Raises
        ------
        InferenceServerException
            If server fails to perform inference.
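
        Examples
        --------
        A hedged end-to-end sketch ('my_model' and the tensor names 'INPUT0'
        and 'OUTPUT0' are placeholders; shapes and datatypes must match the
        model):

        >>> import numpy as np
        >>> from tritonclient.http import InferInput
        >>> data = np.ones((1, 16), dtype=np.float32)
        >>> input0 = InferInput("INPUT0", list(data.shape), "FP32")
        >>> input0.set_data_from_numpy(data)
        >>> result = client.infer("my_model", inputs=[input0])
        >>> output = result.as_numpy("OUTPUT0")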
        """

        request_body, json_size = _get_inference_request(
            inputs=inputs,
            request_id=request_id,
            outputs=outputs,
            sequence_id=sequence_id,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            priority=priority,
            timeout=timeout,
            custom_parameters=parameters,
        )

        if request_compression_algorithm == "gzip":
            if headers is None:
                headers = {}
            headers["Content-Encoding"] = "gzip"
            request_body = gzip.compress(request_body)
        elif request_compression_algorithm == "deflate":
            if headers is None:
                headers = {}
            headers["Content-Encoding"] = "deflate"
            # "Content-Encoding: deflate" actually means compressing in zlib structure
            # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
            request_body = zlib.compress(request_body)

        if response_compression_algorithm == "gzip":
            if headers is None:
                headers = {}
            headers["Accept-Encoding"] = "gzip"
        elif response_compression_algorithm == "deflate":
            if headers is None:
                headers = {}
            headers["Accept-Encoding"] = "deflate"

        if json_size is not None:
            if headers is None:
                headers = {}
            headers["Inference-Header-Content-Length"] = json_size

        if not isinstance(model_version, str):
            raise_error("model version must be a string")
        if model_version != "":
            request_uri = "v2/models/{}/versions/{}/infer".format(
                quote(model_name), model_version
            )
        else:
            request_uri = "v2/models/{}/infer".format(quote(model_name))

        response = self._post(
            request_uri=request_uri,
            request_body=request_body,
            headers=headers,
            query_params=query_params,
        )
        _raise_if_error(response)

        return InferResult(response, self._verbose)

    def async_infer(
        self,
        model_name,
        inputs,
        model_version="",
        outputs=None,
        request_id="",
        sequence_id=0,
        sequence_start=False,
        sequence_end=False,
        priority=0,
        timeout=None,
        headers=None,
        query_params=None,
        request_compression_algorithm=None,
        response_compression_algorithm=None,
        parameters=None,
    ):
        """Run asynchronous inference using the supplied 'inputs' requesting
        the outputs specified by 'outputs'. Even though this call is
        non-blocking, however, the actual number of concurrent requests to
        the server will be limited by the 'concurrency' parameter specified
        while creating this client. In other words, if the inflight
        async_infer exceeds the specified 'concurrency', the delivery of
        the exceeding request(s) to server will be blocked till the slot is
        made available by retrieving the results of previously issued requests.

        Parameters
        ----------
        model_name: str
            The name of the model to run inference.
        inputs : list
            A list of InferInput objects, each describing data for an input
            tensor required by the model.
        model_version: str
            The version of the model to run inference. The default value
            is an empty string, which means the server will choose
            a version based on the model and internal policy.
        outputs : list
            A list of InferRequestedOutput objects, each describing how the output
            data must be returned. If not specified all outputs produced
            by the model will be returned using default settings.
        request_id: str
            Optional identifier for the request. If specified will be returned
            in the response. Default value is an empty string which means no
            request_id will be used.
        sequence_id : int
            The unique identifier for the sequence being represented by the
            object. Default value is 0 which means that the request does not
            belong to a sequence.
        sequence_start: bool
            Indicates whether the request being added marks the start of the
            sequence. Default value is False. This argument is ignored if
            'sequence_id' is 0.
        sequence_end: bool
            Indicates whether the request being added marks the end of the
            sequence. Default value is False. This argument is ignored if
            'sequence_id' is 0.
        priority : int
            Indicates the priority of the request. Priority value zero
            indicates that the default priority level should be used
            (i.e. same behavior as not specifying the priority parameter).
            Lower value priorities indicate higher priority levels. Thus
            the highest priority level is indicated by setting the parameter
            to 1, the next highest is 2, etc. If not provided, the server
            will handle the request using default setting for the model.
        timeout : int
            The timeout value for the request, in microseconds. If the request
            cannot be completed within the time the server can take a
            model-specific action such as terminating the request. If not
            provided, the server will handle the request using default setting
            for the model. This option is only respected by the model that is
            configured with dynamic batching. See here for more details:
            https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher
        headers: dict
            Optional dictionary specifying additional HTTP
            headers to include in the request
        query_params: dict
            Optional url query parameters to use in network
            transaction.
        request_compression_algorithm : str
            Optional HTTP compression algorithm to use for the request body on client side.
            Currently supports "deflate", "gzip" and None. By default, no
            compression is used.
        response_compression_algorithm : str
            Optional HTTP compression algorithm to request for the response body.
            Note that the response may not be compressed if the server does not
            support the specified algorithm. Currently supports "deflate",
            "gzip" and None. By default, no compression is requested.
        parameters : dict
            Optional custom parameters to be included in the inference
            request.

        Returns
        -------
        InferAsyncRequest object
            The handle to the asynchronous inference request.

        Raises
        ------
        InferenceServerException
            If server fails to issue inference.
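
        Examples
        --------
        A minimal sketch; with concurrency=N set on the client, up to N
        requests are in flight at once ('my_model' and 'inputs' are
        placeholders):

        >>> handles = [client.async_infer("my_model", inputs) for _ in range(4)]
        >>> results = [h.get_result() for h in handles]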
        """

        def wrapped_post(request_uri, request_body, headers, query_params):
            return self._post(request_uri, request_body, headers, query_params)

        request_body, json_size = _get_inference_request(
            inputs=inputs,
            request_id=request_id,
            outputs=outputs,
            sequence_id=sequence_id,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            priority=priority,
            timeout=timeout,
            custom_parameters=parameters,
        )

        if request_compression_algorithm == "gzip":
            if headers is None:
                headers = {}
            headers["Content-Encoding"] = "gzip"
            request_body = gzip.compress(request_body)
        elif request_compression_algorithm == "deflate":
            if headers is None:
                headers = {}
            headers["Content-Encoding"] = "deflate"
            # "Content-Encoding: deflate" actually means compressing in zlib structure
            # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
            request_body = zlib.compress(request_body)

        if response_compression_algorithm == "gzip":
            if headers is None:
                headers = {}
            headers["Accept-Encoding"] = "gzip"
        elif response_compression_algorithm == "deflate":
            if headers is None:
                headers = {}
            headers["Accept-Encoding"] = "deflate"

        if json_size is not None:
            if headers is None:
                headers = {}
            headers["Inference-Header-Content-Length"] = json_size

        if not isinstance(model_version, str):
            raise_error("model version must be a string")
        if model_version != "":
            request_uri = "v2/models/{}/versions/{}/infer".format(
                quote(model_name), model_version
            )
        else:
            request_uri = "v2/models/{}/infer".format(quote(model_name))

        g = self._pool.apply_async(
            wrapped_post, (request_uri, request_body, headers, query_params)
        )

        # Schedule the greenlet to run in this loop iteration
        g.start()

        # Relinquish control to greenlet loop. Using non-zero
        # value to ensure the control is transferred to the
        # event loop.
        gevent.sleep(0.01)

        if self._verbose:
            verbose_message = "Sent request"
            if request_id != "":
                verbose_message = verbose_message + " '{}'".format(request_id)
            print(verbose_message)

        return InferAsyncRequest(g, self._verbose)
