# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A TensorSpec class."""

import keyword
from typing import Type

import numpy as np

from tensorflow.core.function import trace_type
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util.tf_export import tf_export


# TODO(b/249802365): Sanitize all TensorSpec names.
def sanitize_spec_name(name: str) -> str:
  """Sanitizes Spec names. Matches Graph Node and Python naming conventions.

  Without sanitization, names that are not legal Python parameter names can be
  set which makes it challenging to represent callables supporting the named
  calling capability.

  The name is lower-cased and every non-alphanumeric character is replaced by
  `_`. If the result does not start with a letter, or is a Python keyword
  (e.g. `"class"`, `"lambda"`), it is prefixed with `"tensor_"`.

  Args:
    name: The name to sanitize. A falsy value (`None` or `""`) is accepted.

  Returns:
    A string that meets Python parameter conventions; `"unknown"` when `name`
    is falsy.
  """
  if not name:
    return "unknown"

  # Lower case and replace non-alphanumeric chars with '_'
  swapped = "".join([c if c.isalnum() else "_" for c in name.lower()])

  # Parameter names must start with a letter. Keywords are purely alphabetic
  # but still cannot be used as parameter names, so they get the prefix too.
  if swapped[0].isalpha() and not keyword.iskeyword(swapped):
    return swapped
  return "tensor_" + swapped


class DenseSpec(type_spec.TypeSpec):
  """Describes a dense object with shape, dtype, and name."""

  __slots__ = ["_shape", "_dtype", "_name"]

  @property
  def _component_specs(self):
    # A dense spec acts as its own component spec.
    return self

  def __init__(self, shape, dtype=dtypes.float32, name=None):
    """Creates a DenseSpec.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      name: Optional name for the Tensor.

    Raises:
      TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
        not convertible to a `tf.DType`.
    """
    # The shape is converted before the dtype, so a bad shape is reported first.
    self._shape = tensor_shape.TensorShape(shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._name = name

  @property
  def shape(self):
    """The `TensorShape` describing the shape of the tensor."""
    return self._shape

  @property
  def dtype(self):
    """The `dtype` of the tensor's elements."""
    return self._dtype

  @property
  def name(self):
    """The name given to the described tensor, or `None` if none was given."""
    return self._name

  def is_compatible_with(self, spec_or_value):
    """Returns whether `spec_or_value` has a compatible dtype and shape.

    Only `DenseSpec` instances and instances of `self.value_type` can be
    compatible; everything else is rejected outright.
    """
    accepted_types = (DenseSpec, self.value_type)
    if not isinstance(spec_or_value, accepted_types):
      return False
    dtype_ok = self._dtype.is_compatible_with(spec_or_value.dtype)
    return dtype_ok and self._shape.is_compatible_with(spec_or_value.shape)

  def __repr__(self):
    cls_name = type(self).__name__
    return (f"{cls_name}(shape={self.shape}, dtype={self.dtype!r}, "
            f"name={self.name!r})")

  def __hash__(self):
    # The name is deliberately left out; equal specs still hash equally.
    hash_key = (self._shape, self.dtype)
    return hash(hash_key)

  def __eq__(self, other):
    # pylint: disable=protected-access
    if type(self) is not type(other):
      return False
    return (self._shape == other._shape and
            self._dtype == other._dtype and
            self._name == other._name)

  def __ne__(self, other):
    return not (self == other)

  def _serialize(self):
    serialized = (self._shape, self._dtype, self._name)
    return serialized

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return self.value_type


@tf_export("TensorSpec")
@type_spec.register("tf.TensorSpec")
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec,
                 trace_type.Serializable):
  """Describes a tf.Tensor.

  Metadata for describing the `tf.Tensor` objects accepted or returned
  by some TensorFlow APIs.
  """

  __slots__ = []

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.TensorSpecProto]:
    """Returns the type of proto associated with TensorSpec serialization."""
    return struct_pb2.TensorSpecProto

  @classmethod
  def experimental_from_proto(
      cls, proto: struct_pb2.TensorSpecProto) -> "TensorSpec":
    """Returns a TensorSpec instance based on the serialized proto."""
    return TensorSpec(
        shape=tensor_shape.TensorShape.experimental_from_proto(proto.shape),
        dtype=proto.dtype,
        # An empty name in the proto maps back to "no name".
        name=proto.name if proto.name else None)

  def experimental_as_proto(self) -> struct_pb2.TensorSpecProto:
    """Returns a proto representation of the TensorSpec instance."""
    return struct_pb2.TensorSpecProto(
        shape=self.shape.experimental_as_proto(),
        dtype=self.dtype.experimental_as_proto().datatype,
        name=self.name)

  def is_compatible_with(self, spec_or_tensor):  # pylint:disable=useless-super-delegation
    """Returns True if spec_or_tensor is compatible with this TensorSpec.

    Two tensors are considered compatible if they have the same dtype
    and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).

    Args:
      spec_or_tensor: A tf.TensorSpec or a tf.Tensor

    Returns:
      True if spec_or_tensor is compatible with self.
    """
    return super(TensorSpec, self).is_compatible_with(spec_or_tensor)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `TensorSpec` with the same shape and dtype as `spec`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="OriginalName")
    >>> tf.TensorSpec.from_spec(spec, "NewName")
    TensorSpec(shape=(8, 3), dtype=tf.int32, name='NewName')

    Args:
      spec: The `TypeSpec` used to create the new `TensorSpec`.
      name: The name for the new `TensorSpec`.  Defaults to `spec.name`, or
        `None` if `spec` has no `name` attribute.

    Returns:
      A new `cls` instance with `spec`'s shape and dtype.
    """
    # `spec` may be any TypeSpec, which is not guaranteed to expose `name`;
    # fall back to no name instead of raising AttributeError.
    return cls(spec.shape, spec.dtype, name or getattr(spec, "name", None))

  @classmethod
  def from_tensor(cls, tensor, name=None):
    """Returns a `TensorSpec` that describes `tensor`.

    >>> tf.TensorSpec.from_tensor(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)

    Args:
      tensor: The `tf.Tensor` that should be described.
      name: A name for the `TensorSpec`.  For graph tensors this defaults to
        `tensor.op.name`; eager tensors get no default name.

    Returns:
      A `TensorSpec` that describes `tensor`.

    Raises:
      ValueError: If `tensor` is not a `tf.Tensor`.
    """
    # EagerTensor is checked first; eager tensors do not get a default name.
    if isinstance(tensor, ops.EagerTensor):
      return TensorSpec(tensor.shape, tensor.dtype, name)
    elif isinstance(tensor, ops.Tensor):
      # TODO(b/249802365): Return a sanitized version of op name or no name.
      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
    else:
      raise ValueError(
          f"`tensor` should be a tf.Tensor, but got type {type(tensor)}.")

  @property
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec."""
    return ops.Tensor

  def _to_components(self, value):
    """Converts `value` to a tensor of this spec's dtype and checks its shape.

    Raises:
      ValueError: If `value` cannot be converted to `self.dtype`, or the
        converted tensor's shape is incompatible with `self.shape`.
    """
    try:
      value = ops.convert_to_tensor(value, self._dtype)
    except (TypeError, ValueError) as e:
      raise ValueError(f"Value {value} is not convertible to a tensor with "
                       f"dtype {self._dtype} and shape {self._shape}.") from e
    if not value.shape.is_compatible_with(self._shape):
      raise ValueError(f"Value {value} is not convertible to a tensor with "
                       f"dtype {self._dtype} and shape {self._shape}.")
    return value

  def _from_components(self, components):
    return components

  def _from_compatible_tensor_list(self, tensor_list):
    # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
    # op here and return that, instead of mutating the input's shape using
    # `Tensor.set_shape()`. However, that would add extra ops, which could
    # impact performance. When this bug is resolved, we should be able to add
    # the `ensure_shape()` ops and optimize them away using contextual shape
    # information.
    assert len(tensor_list) == 1
    tensor_list[0].set_shape(self._shape)
    return tensor_list[0]

  def _to_batchable_tensor_list(self, value, batched=False):
    # Convert (and validate) first: raw Python values such as lists have no
    # `.shape`, and the converted tensor's shape is known to be compatible.
    tensor = self._to_components(value)
    if batched and self._shape.merge_with(tensor.shape).ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    # NOTE(review): this returns a single tensor rather than a list (unlike
    # `_to_tensor_list`); confirm what callers expect before changing it.
    return tensor

  def _batch(self, batch_size):
    # The batched spec prepends a `batch_size` dimension and drops the name.
    return TensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype)

  def _unbatch(self):
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    # The unbatched spec strips the leading dimension and drops the name.
    return TensorSpec(self._shape[1:], self._dtype)

  @property
  def _flat_tensor_specs(self):
    return [self]

  def _to_tensor_list(self, value):
    return [self._to_components(value)]

  def _to_batched_tensor_list(self, value):
    return self._to_tensor_list(value)

  # TODO(b/206014848): Helper function to support logic that does not consider
  # Tensor name. Will be removed once load-bearing usages of Tensor name are
  # fixed.
  def _without_tensor_names(self) -> "TensorSpec":
    """Returns a version of `TensorSpec` with the name removed."""
    if self.name is None:
      return self
    else:
      return TensorSpec(self.shape, self.dtype)

# Register TensorSpec with trace_type's serialization registry (presumably so
# it can be round-tripped through the experimental_*_proto methods above).
trace_type.register_serializable(TensorSpec)


# TODO(b/133606651): Should is_compatible_with check min/max bounds?
@type_spec.register("tf.BoundedTensorSpec")
class BoundedTensorSpec(TensorSpec, trace_type.Serializable):
  """A `TensorSpec` that specifies minimum and maximum values.

  Example usage:
  ```python
  spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5))
  tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype)
  tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype)
  ```

  Bounds are meant to be inclusive. This is especially important for
  integer types. The following spec will be satisfied by tensors
  with values in the set {0, 1, 2}:
  ```python
  spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2)
  ```
  """

  __slots__ = ("_minimum", "_maximum")

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Initializes a new `BoundedTensorSpec`.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` are not provided or not
        broadcastable to `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
        numpy dtype.
    """
    super(BoundedTensorSpec, self).__init__(shape, dtype, name)

    if minimum is None:
      raise ValueError("`minimum` can not be None.")
    if maximum is None:
      raise ValueError("`maximum` can not be None.")

    try:
      minimum_shape = np.shape(minimum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(minimum_shape), self.shape)
    except ValueError as exception:
      raise ValueError(
          f"`minimum` {minimum} is not compatible with shape "
          f"{self.shape}. Original error: {exception!r}.") from exception

    try:
      maximum_shape = np.shape(maximum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(maximum_shape), self.shape)
    except ValueError as exception:
      raise ValueError(
          f"`maximum` {maximum} is not compatible with shape "
          f"{self.shape}. Original error: {exception!r}.") from exception

    # Bounds are stored as read-only NumPy arrays of the spec's dtype so that
    # callers cannot mutate a spec's bounds in place.
    self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype)
    self._minimum.setflags(write=False)

    self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype)
    self._maximum.setflags(write=False)

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.BoundedTensorSpecProto]:
    """Returns the type of proto associated with BoundedTensorSpec serialization."""
    return struct_pb2.BoundedTensorSpecProto

  @classmethod
  def experimental_from_proto(
      cls, proto: struct_pb2.BoundedTensorSpecProto) -> "BoundedTensorSpec":
    """Returns a BoundedTensorSpec instance based on the serialized proto."""
    return BoundedTensorSpec(
        shape=tensor_shape.TensorShape.experimental_from_proto(proto.shape),
        dtype=proto.dtype,
        minimum=tensor_util.MakeNdarray(proto.minimum),
        maximum=tensor_util.MakeNdarray(proto.maximum),
        name=proto.name if proto.name else None)

  def experimental_as_proto(self) -> struct_pb2.BoundedTensorSpecProto:
    """Returns a proto representation of the BoundedTensorSpec instance."""
    return struct_pb2.BoundedTensorSpecProto(
        shape=self.shape.experimental_as_proto(),
        dtype=self.dtype.experimental_as_proto().datatype,
        minimum=tensor_util.make_tensor_proto(self._minimum),
        maximum=tensor_util.make_tensor_proto(self._maximum),
        name=self.name)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `BoundedTensorSpec` with the same shape and dtype as `spec`.

    If `spec` is a `BoundedTensorSpec`, then the new spec's bounds are set to
    `spec.minimum` and `spec.maximum`; otherwise, the bounds are set to
    `spec.dtype.min` and `spec.dtype.max`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="x")
    >>> BoundedTensorSpec.from_spec(spec)
    BoundedTensorSpec(shape=(8, 3), dtype=tf.int32, name='x',
        minimum=array(-2147483648, dtype=int32),
        maximum=array(2147483647, dtype=int32))

    Args:
      spec: The `TypeSpec` used to create the new `BoundedTensorSpec`.
      name: Optional name for the new `BoundedTensorSpec`. Defaults to
        `spec.name`, or `None` if `spec` has no `name` attribute.

    Returns:
      A new `BoundedTensorSpec`.
    """
    dtype = dtypes.as_dtype(spec.dtype)
    # Only consult the dtype's range when `spec` has no bounds of its own.
    # `getattr(spec, "minimum", dtype.min)` would evaluate `dtype.min` eagerly,
    # which raises for dtypes without a range (e.g. tf.bool, tf.string) even
    # when `spec` already carries explicit bounds.
    minimum = spec.minimum if hasattr(spec, "minimum") else dtype.min
    maximum = spec.maximum if hasattr(spec, "maximum") else dtype.max
    return BoundedTensorSpec(spec.shape, dtype, minimum, maximum,
                             name or getattr(spec, "name", None))

  @property
  def minimum(self):
    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
    return self._minimum

  @property
  def maximum(self):
    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
    return self._maximum

  def __repr__(self):
    s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})"
    return s.format(self.shape, repr(self.dtype), repr(self.name),
                    repr(self.minimum), repr(self.maximum))

  def __eq__(self, other):
    # The superclass check guarantees `other` is also a BoundedTensorSpec.
    if not super(BoundedTensorSpec, self).__eq__(other):
      return False
    try:
      return (np.allclose(self.minimum, other.minimum) and
              np.allclose(self.maximum, other.maximum))
    except ValueError:
      # `np.allclose` raises when the two bounds cannot be broadcast against
      # each other; such specs are unequal, and `==` should not raise.
      return False

  def __hash__(self):
    # Must be redefined: defining `__eq__` in a class sets its `__hash__` to
    # None. Bounds are excluded, which is consistent with `__eq__`.
    return hash((self._shape, self.dtype))

  def __reduce__(self):
    return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
                               self._maximum, self._name)

  def _serialize(self):
    return (self._shape, self._dtype, self._minimum, self._maximum, self._name)

# Register BoundedTensorSpec with trace_type's serialization registry.
trace_type.register_serializable(BoundedTensorSpec)
# Register TensorSpec with the native `_pywrap_utils` helpers under the name
# "TensorSpec" (presumably so native code can recognize TensorSpec instances).
_pywrap_utils.RegisterType("TensorSpec", TensorSpec)

# Note: we do not include Tensor names when constructing TypeSpecs.
# Both converters below build an unnamed TensorSpec from the value's shape and
# dtype: one for tf.Tensor values, one for NumPy arrays.
type_spec.register_type_spec_from_value_converter(
    ops.Tensor, lambda tensor: TensorSpec(tensor.shape, tensor.dtype))

type_spec.register_type_spec_from_value_converter(
    np.ndarray, lambda array: TensorSpec(array.shape, array.dtype))
