From 76d4c1ad7eee8bda2daee6d6a3e388ab58ca8a1a Mon Sep 17 00:00:00 2001
From: "Kazantsev, Roman"
Date: Fri, 17 May 2024 19:14:25 +0400
Subject: [PATCH 01/19] [POC][OV] Support OpenVINO as Keras 3 backend

Signed-off-by: Kazantsev, Roman
---
 keras/src/backend/__init__.py          |    4 +
 keras/src/backend/exports.py           |    5 +
 keras/src/backend/openvino/__init__.py |   21 +
 keras/src/backend/openvino/core.py     |  258 ++
 keras/src/backend/openvino/image.py    |  385 +++
 keras/src/backend/openvino/layer.py    |    2 +
 keras/src/backend/openvino/linalg.py   |   88 ++
 keras/src/backend/openvino/math.py     |  329 +++
 keras/src/backend/openvino/nn.py       |  960 +++++++++
 keras/src/backend/openvino/numpy.py    | 1087 ++++++++++++++++++++++++
 keras/src/backend/openvino/random.py   |  120 +++
 keras/src/backend/openvino/rnn.py      |  239 ++++++
 keras/src/backend/openvino/trainer.py  |  322 +++++++
 keras/src/layers/layer.py              |    5 +
 keras/src/models/functional.py         |   10 +
 keras/src/models/model.py              |    2 +
 keras/src/ops/function.py              |   35 +
 keras/src/utils/backend_utils.py       |    3 +
 18 files changed, 3875 insertions(+)
 create mode 100644 keras/src/backend/openvino/__init__.py
 create mode 100644 keras/src/backend/openvino/core.py
 create mode 100644 keras/src/backend/openvino/image.py
 create mode 100644 keras/src/backend/openvino/layer.py
 create mode 100644 keras/src/backend/openvino/linalg.py
 create mode 100644 keras/src/backend/openvino/math.py
 create mode 100644 keras/src/backend/openvino/nn.py
 create mode 100644 keras/src/backend/openvino/numpy.py
 create mode 100644 keras/src/backend/openvino/random.py
 create mode 100644 keras/src/backend/openvino/rnn.py
 create mode 100644 keras/src/backend/openvino/trainer.py

diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py
index 5c7fa223520..527524699d1 100644
--- a/keras/src/backend/__init__.py
+++ b/keras/src/backend/__init__.py
@@ -40,6 +40,10 @@ elif backend() == "numpy":
     from keras.src.backend.numpy import *  # noqa: F403
 
     distribution_lib = None
+elif backend() == "openvino":
+    from keras.src.backend.openvino import *  # noqa: F403
+
+    distribution_lib = None
 else:
     raise ValueError(f"Unable to import backend : {backend()}")
 
diff --git a/keras/src/backend/exports.py b/keras/src/backend/exports.py
index 54ee1c74bb8..0b2a0447c87 100644
--- a/keras/src/backend/exports.py
+++ b/keras/src/backend/exports.py
@@ -15,6 +15,11 @@
     BackendVariable = NumpyVariable
     backend_name_scope = backend.common.name_scope.name_scope
+elif backend.backend() == "openvino":
+    from keras.src.backend.openvino.core import Variable as OpenVINOVariable
+
+    BackendVariable = OpenVINOVariable
+    backend_name_scope = backend.common.name_scope.name_scope
 else:
     raise RuntimeError(f"Invalid backend: {backend.backend()}")
 
diff --git a/keras/src/backend/openvino/__init__.py b/keras/src/backend/openvino/__init__.py
new file mode 100644
index 00000000000..c54a33c809a
--- /dev/null
+++ b/keras/src/backend/openvino/__init__.py
@@ -0,0 +1,21 @@
+from keras.src.backend.openvino import core
+from keras.src.backend.openvino import image
+from keras.src.backend.openvino import linalg
+from keras.src.backend.openvino import math
+from keras.src.backend.openvino import nn
+from keras.src.backend.openvino import numpy
+from keras.src.backend.openvino import random
+from keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS
+from keras.src.backend.openvino.core import Variable
+from keras.src.backend.openvino.core import cast
+from keras.src.backend.openvino.core import compute_output_spec
+from keras.src.backend.openvino.core import cond
+from keras.src.backend.openvino.core import convert_to_numpy
+from keras.src.backend.openvino.core import convert_to_tensor
+from keras.src.backend.openvino.core import is_tensor
+from keras.src.backend.openvino.core import shape
+from keras.src.backend.openvino.core import vectorized_map
+from keras.src.backend.openvino.rnn import cudnn_ok
+from keras.src.backend.openvino.rnn import gru
+from keras.src.backend.openvino.rnn import lstm
+from keras.src.backend.openvino.rnn import rnn
diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py
new file mode 100644
index 00000000000..e1d249508a3
--- /dev/null
+++ b/keras/src/backend/openvino/core.py
@@ -0,0 +1,258 @@
+import contextlib
+
+import numpy as np
+import openvino as ov
+
+from keras.src import tree
+from keras.src.backend.common import KerasVariable
+from keras.src.backend.common import global_state
+from keras.src.backend.common import standardize_dtype
+from keras.src.backend.common.dtypes import result_type
+from keras.src.backend.common.keras_tensor import KerasTensor
+from keras.src.backend.common.stateless_scope import StatelessScope
+
+SUPPORTS_SPARSE_TENSORS = False
+
+OPENVINO_DTYPES = {
+    "float16": ov.Type.f16,
+    "float32": ov.Type.f32,
+    "float64": ov.Type.f64,
+    "uint8": ov.Type.u8,
+    "uint16": ov.Type.u16,
+    "uint32": ov.Type.u32,
+    "int8": ov.Type.i8,
+    "int16": ov.Type.i16,
+    "int32": ov.Type.i32,
+    "int64": ov.Type.i64,
+    "bfloat16": ov.Type.bf16,
+    "bool": ov.Type.boolean,
+    "float8_e4m3fn": ov.Type.f8e4m3,
+    "float8_e5m2": ov.Type.f8e5m2,
+}
+
+
+def ov_to_keras_type(ov_type):
+    for _keras_type, _ov_type in OPENVINO_DTYPES.items():
+        if ov_type == _ov_type:
+            return _keras_type
+    raise ValueError(
+        f"Requested OpenVINO type has no Keras analogue: "
+        f"'{ov_type.to_string()}'"
+    )
+
+
+@contextlib.contextmanager
+def device_scope(device_name):
+    previous_device = global_state.get_global_attribute(
+        "openvino_device", None
+    )
+    current_device = _parse_device_input(device_name)
+    global_state.set_global_attribute("openvino_device", current_device)
+    try:
+        yield
+    finally:
+        global_state.set_global_attribute("openvino_device", previous_device)
+
+
+def get_device():
+    device = global_state.get_global_attribute("openvino_device", None)
+    if device is None:
+        return "CPU"
+    return device
+
+
+def _parse_device_input(device_name):
+    if isinstance(device_name, str):
+        # We support strings like "cpu:0" or "gpu:1" and extract the
+        # uppercase device type ("CPU", "GPU") that OpenVINO expects.
+        device_name = device_name.upper()
+        device_type = device_name.split(":")[0]
+        return device_type
+    raise ValueError(
+        "Invalid value for argument `device_name`. "
+        "Expected a string like 'gpu:0' or 'cpu'. "
+        f"Received: device_name='{device_name}'"
+    )
+
+
+class Variable(KerasVariable):
+    def _initialize(self, value):
+        self._value = np.array(value, dtype=self._dtype)
+
+    def _direct_assign(self, value):
+        self._value = np.array(value, dtype=self._dtype)
+
+    def _convert_to_tensor(self, value, dtype=None):
+        return convert_to_tensor(value, dtype=dtype)
+
+    # Overload native accessor.
+    def __array__(self):
+        return self.value
+
+
+def convert_to_tensor(x, dtype=None, sparse=None):
+    if sparse:
+        raise ValueError(
+            "`sparse=True` is not supported with openvino backend"
+        )
+    if dtype is not None:
+        dtype = standardize_dtype(dtype)
+    if isinstance(x, Variable):
+        if dtype and dtype != x.dtype:
+            return x.value.astype(dtype)
+        return x.value
+    if not is_tensor(x) and standardize_dtype(dtype) == "bfloat16":
+        # Can't create bfloat16 arrays on the fly (e.g. from a h5 Dataset).
+        # Instead we convert "as is" (to stored dtype) and cast.
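+        # A hypothetical sketch of this two-step path (assuming the
+        # `ml_dtypes` package has registered "bfloat16" with NumPy, as
+        # Keras' other backends rely on):
+        #     raw = np.asarray(h5_dataset)   # materialize in stored dtype
+        #     out = raw.astype("bfloat16")   # then cast to bfloat16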
+        return np.asarray(x).astype(dtype)
+    if dtype is None:
+        dtype = result_type(
+            *[getattr(item, "dtype", type(item)) for item in tree.flatten(x)]
+        )
+    return np.array(x, dtype=dtype)
+
+
+def convert_to_numpy(x):
+    return np.array(x)
+
+
+def is_tensor(x):
+    if isinstance(x, (np.generic, np.ndarray)):
+        return True
+    return False
+
+
+def shape(x):
+    return x.shape
+
+
+def cast(x, dtype):
+    raise NotImplementedError(
+        "`cast` is not supported with openvino backend"
+    )
+
+
+def cond(pred, true_fn, false_fn):
+    raise NotImplementedError(
+        "`cond` is not supported with openvino backend"
+    )
+
+
+def vectorized_map(function, elements):
+    raise NotImplementedError(
+        "`vectorized_map` is not supported with openvino backend"
+    )
+
+
+# Shape / dtype inference util
+def compute_output_spec(fn, *args, **kwargs):
+    with StatelessScope():
+
+        def has_none_shape(x):
+            if isinstance(x, KerasTensor):
+                return None in x.shape
+            return False
+
+        none_in_shape = any(map(has_none_shape, tree.flatten((args, kwargs))))
+
+        def convert_keras_tensor_to_numpy(x, fill_value=None):
+            if isinstance(x, KerasTensor):
+                shape = list(x.shape)
+                if fill_value:
+                    for i, e in enumerate(shape):
+                        if e is None:
+                            shape[i] = fill_value
+                return np.empty(
+                    shape=shape,
+                    dtype=x.dtype,
+                )
+            return x
+
+        args_1, kwargs_1 = tree.map_structure(
+            lambda x: convert_keras_tensor_to_numpy(x, fill_value=83),
+            (args, kwargs),
+        )
+        outputs_1 = fn(*args_1, **kwargs_1)
+
+        outputs = outputs_1
+
+        if none_in_shape:
+            args_2, kwargs_2 = tree.map_structure(
+                lambda x: convert_keras_tensor_to_numpy(x, fill_value=89),
+                (args, kwargs),
+            )
+            outputs_2 = fn(*args_2, **kwargs_2)
+
+            flat_out_1 = tree.flatten(outputs_1)
+            flat_out_2 = tree.flatten(outputs_2)
+
+            flat_out = []
+            for x1, x2 in zip(flat_out_1, flat_out_2):
+                shape = list(x1.shape)
+                for i, e in enumerate(x2.shape):
+                    if e != shape[i]:
+                        shape[i] = None
+                flat_out.append(
+                    KerasTensor(shape, standardize_dtype(x1.dtype))
+                )
+            outputs = tree.pack_sequence_as(outputs_1, flat_out)
+
+        def convert_numpy_to_keras_tensor(x):
+            if is_tensor(x):
+                return KerasTensor(x.shape, standardize_dtype(x.dtype))
+            return x
+
+        output_spec = tree.map_structure(
+            convert_numpy_to_keras_tensor, outputs
+        )
+    return output_spec
+
+
+def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
+    raise NotImplementedError(
+        "`scan` is not supported with openvino backend"
+    )
+
+
+def scatter(indices, values, shape):
+    raise NotImplementedError(
+        "`scatter` is not supported with openvino backend"
+    )
+
+
+def scatter_update(inputs, indices, updates):
+    raise NotImplementedError(
+        "`scatter_update` is not supported with openvino backend"
+    )
+
+
+def slice(inputs, start_indices, lengths):
+    raise NotImplementedError(
+        "`slice` is not supported with openvino backend"
+    )
+
+
+def slice_update(inputs, start_indices, updates):
+    raise NotImplementedError(
+        "`slice_update` is not supported with openvino backend"
+    )
+
+
+def while_loop(
+    cond,
+    body,
+    loop_vars,
+    maximum_iterations=None,
+):
+    raise NotImplementedError(
+        "`while_loop` is not supported with openvino backend"
+    )
+
+
+def fori_loop(lower, upper, body_fun, init_val):
+    raise NotImplementedError(
+        "`fori_loop` is not supported with openvino backend"
+    )
+
+
+def stop_gradient(x):
+    return x
+
+
+def unstack(x, num=None, axis=0):
+    raise NotImplementedError(
+        "`unstack` is not supported with openvino backend"
+    )
+
+
+def custom_gradient(fun):
+    raise NotImplementedError(
+        "`custom_gradient` is not supported with openvino backend"
+ ) diff --git a/keras/src/backend/openvino/image.py b/keras/src/backend/openvino/image.py new file mode 100644 index 00000000000..d385da3d834 --- /dev/null +++ b/keras/src/backend/openvino/image.py @@ -0,0 +1,385 @@ +import jax +import numpy as np + +from keras.src.backend.openvino.core import convert_to_tensor +from keras.src.utils.module_utils import scipy + +RESIZE_INTERPOLATIONS = ( + "bilinear", + "nearest", + "lanczos3", + "lanczos5", + "bicubic", +) + + +def rgb_to_grayscale(image, data_format="channels_last"): + if data_format == "channels_first": + if len(image.shape) == 4: + image = np.transpose(image, (0, 2, 3, 1)) + elif len(image.shape) == 3: + image = np.transpose(image, (1, 2, 0)) + else: + raise ValueError( + "Invalid input rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"image.shape={image.shape}" + ) + red, green, blue = image[..., 0], image[..., 1], image[..., 2] + grayscale_image = 0.2989 * red + 0.5870 * green + 0.1140 * blue + grayscale_image = np.expand_dims(grayscale_image, axis=-1) + if data_format == "channels_first": + if len(image.shape) == 4: + grayscale_image = np.transpose(grayscale_image, (0, 3, 1, 2)) + elif len(image.shape) == 3: + grayscale_image = np.transpose(grayscale_image, (2, 0, 1)) + return np.array(grayscale_image) + + +def resize( + image, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format="channels_last", +): + if interpolation not in RESIZE_INTERPOLATIONS: + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}" + ) + if fill_mode != "constant": + raise ValueError( + "Invalid value for argument `fill_mode`. Only `'constant'` " + f"is supported. Received: fill_mode={fill_mode}" + ) + if pad_to_aspect_ratio and crop_to_aspect_ratio: + raise ValueError( + "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` " + "can be `True`." + ) + if not len(size) == 2: + raise ValueError( + "Argument `size` must be a tuple of two elements " + f"(height, width). Received: size={size}" + ) + size = tuple(size) + target_height, target_width = size + if len(image.shape) == 4: + if data_format == "channels_last": + size = (image.shape[0],) + size + (image.shape[-1],) + else: + size = (image.shape[0], image.shape[1]) + size + elif len(image.shape) == 3: + if data_format == "channels_last": + size = size + (image.shape[-1],) + else: + size = (image.shape[0],) + size + else: + raise ValueError( + "Invalid input rank: expected rank 3 (single image) " + "or rank 4 (batch of images). 
Received input with shape: " + f"image.shape={image.shape}" + ) + + if crop_to_aspect_ratio: + shape = image.shape + if data_format == "channels_last": + height, width = shape[-3], shape[-2] + else: + height, width = shape[-2], shape[-1] + crop_height = int(float(width * target_height) / target_width) + crop_height = min(height, crop_height) + crop_width = int(float(height * target_width) / target_height) + crop_width = min(width, crop_width) + crop_box_hstart = int(float(height - crop_height) / 2) + crop_box_wstart = int(float(width - crop_width) / 2) + if data_format == "channels_last": + if len(image.shape) == 4: + image = image[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + image = image[ + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + if len(image.shape) == 4: + image = image[ + :, + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + else: + image = image[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + elif pad_to_aspect_ratio: + shape = image.shape + batch_size = image.shape[0] + if data_format == "channels_last": + height, width, channels = shape[-3], shape[-2], shape[-1] + else: + channels, height, width = shape[-3], shape[-2], shape[-1] + pad_height = int(float(width * target_height) / target_width) + pad_height = max(height, pad_height) + pad_width = int(float(height * target_width) / target_height) + pad_width = max(width, pad_width) + img_box_hstart = int(float(pad_height - height) / 2) + img_box_wstart = int(float(pad_width - width) / 2) + if data_format == "channels_last": + if len(image.shape) == 4: + padded_img = ( + np.ones( + ( + batch_size, + pad_height + height, + pad_width + width, + channels, + ), + dtype=image.dtype, + ) + * fill_value + ) + padded_img[ + :, + img_box_hstart : img_box_hstart + height, + img_box_wstart : img_box_wstart + width, + :, + ] = image + else: + padded_img = ( + np.ones( + (pad_height + height, pad_width + width, channels), + dtype=image.dtype, + ) + * fill_value + ) + padded_img[ + img_box_hstart : img_box_hstart + height, + img_box_wstart : img_box_wstart + width, + :, + ] = image + else: + if len(image.shape) == 4: + padded_img = ( + np.ones( + ( + batch_size, + channels, + pad_height + height, + pad_width + width, + ), + dtype=image.dtype, + ) + * fill_value + ) + padded_img[ + :, + :, + img_box_hstart : img_box_hstart + height, + img_box_wstart : img_box_wstart + width, + ] = image + else: + padded_img = ( + np.ones( + (channels, pad_height + height, pad_width + width), + dtype=image.dtype, + ) + * fill_value + ) + padded_img[ + :, + img_box_hstart : img_box_hstart + height, + img_box_wstart : img_box_wstart + width, + ] = image + image = padded_img + + return np.array( + jax.image.resize(image, size, method=interpolation, antialias=antialias) + ) + + +AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order + "nearest": 0, + "bilinear": 1, +} +AFFINE_TRANSFORM_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} + + +def affine_transform( + image, + transform, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + data_format="channels_last", +): + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): + raise ValueError( + "Invalid value for argument `interpolation`. 
Expected of one " + f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " + f"interpolation={interpolation}" + ) + if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected of one " + f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" + ) + + transform = convert_to_tensor(transform) + + if len(image.shape) not in (3, 4): + raise ValueError( + "Invalid image rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"image.shape={image.shape}" + ) + if len(transform.shape) not in (1, 2): + raise ValueError( + "Invalid transform rank: expected rank 1 (single transform) " + "or rank 2 (batch of transforms). Received input with shape: " + f"transform.shape={transform.shape}" + ) + + # scipy.ndimage.map_coordinates lacks support for half precision. + input_dtype = image.dtype + if input_dtype == "float16": + image = image.astype("float32") + + # unbatched case + need_squeeze = False + if len(image.shape) == 3: + image = np.expand_dims(image, axis=0) + need_squeeze = True + if len(transform.shape) == 1: + transform = np.expand_dims(transform, axis=0) + + if data_format == "channels_first": + image = np.transpose(image, (0, 2, 3, 1)) + + batch_size = image.shape[0] + + # get indices + meshgrid = np.meshgrid( + *[np.arange(size) for size in image.shape[1:]], indexing="ij" + ) + indices = np.concatenate( + [np.expand_dims(x, axis=-1) for x in meshgrid], axis=-1 + ) + indices = np.tile(indices, (batch_size, 1, 1, 1, 1)) + + # swap the values + a0 = transform[:, 0].copy() + a2 = transform[:, 2].copy() + b1 = transform[:, 4].copy() + b2 = transform[:, 5].copy() + transform[:, 0] = b1 + transform[:, 2] = b2 + transform[:, 4] = a0 + transform[:, 5] = a2 + + # deal with transform + transform = np.pad(transform, pad_width=[[0, 0], [0, 1]], constant_values=1) + transform = np.reshape(transform, (batch_size, 3, 3)) + offset = transform[:, 0:2, 2].copy() + offset = np.pad(offset, pad_width=[[0, 0], [0, 1]]) + transform[:, 0:2, 2] = 0 + + # transform the indices + coordinates = np.einsum("Bhwij, Bjk -> Bhwik", indices, transform) + coordinates = np.moveaxis(coordinates, source=-1, destination=1) + coordinates += np.reshape(a=offset, newshape=(*offset.shape, 1, 1, 1)) + + # apply affine transformation + affined = np.stack( + [ + map_coordinates( + image[i], + coordinates[i], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for i in range(batch_size) + ], + axis=0, + ) + + if data_format == "channels_first": + affined = np.transpose(affined, (0, 3, 1, 2)) + if need_squeeze: + affined = np.squeeze(affined, axis=0) + if input_dtype == "float16": + affined = affined.astype(input_dtype) + return affined + + +MAP_COORDINATES_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} + + +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + if fill_mode not in MAP_COORDINATES_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(MAP_COORDINATES_FILL_MODES.keys())}. Received: " + f"fill_mode={fill_mode}" + ) + if order not in range(2): + raise ValueError( + "Invalid value for argument `order`. Expected one of " + f"{[0, 1]}. Received: order={order}" + ) + # SciPy's implementation of map_coordinates handles boundaries incorrectly, + # unless mode='reflect'. 
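+    # As a workaround, the code below pads the input and shifts the
+    # coordinates into the padded frame before calling SciPy.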
For order=1, this only affects interpolation + # outside the bounds of the original array. + # https://github.com/scipy/scipy/issues/2640 + padding = [ + ( + max(-np.floor(c.min()).astype(int) + 1, 0), + max(np.ceil(c.max()).astype(int) + 1 - size, 0), + ) + for c, size in zip(coordinates, input.shape) + ] + shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)] + pad_mode = { + "nearest": "edge", + "mirror": "reflect", + "reflect": "symmetric", + }.get(fill_mode, fill_mode) + if fill_mode == "constant": + padded = np.pad( + input, padding, mode=pad_mode, constant_values=fill_value + ) + else: + padded = np.pad(input, padding, mode=pad_mode) + result = scipy.ndimage.map_coordinates( + padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value + ) + return result diff --git a/keras/src/backend/openvino/layer.py b/keras/src/backend/openvino/layer.py new file mode 100644 index 00000000000..08b761f972e --- /dev/null +++ b/keras/src/backend/openvino/layer.py @@ -0,0 +1,2 @@ +class NumpyLayer: + pass diff --git a/keras/src/backend/openvino/linalg.py b/keras/src/backend/openvino/linalg.py new file mode 100644 index 00000000000..c925f3fbee0 --- /dev/null +++ b/keras/src/backend/openvino/linalg.py @@ -0,0 +1,88 @@ +import numpy as np +import scipy.linalg as sl + +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.openvino.core import convert_to_tensor + + +def cholesky(a): + return np.linalg.cholesky(a) + + +def det(a): + return np.linalg.det(a) + + +def eig(a): + return np.linalg.eig(a) + + +def eigh(a): + return np.linalg.eigh(a) + + +def inv(a): + return np.linalg.inv(a) + + +def lu_factor(a): + if a.ndim == 2: + return sl.lu_factor(a) + + m, n = a.shape[-2:] + signature = "(m,n) -> (m,n), " + signature += "(m)" if m <= n else "(n)" + _lu_factor_gufunc = np.vectorize( + sl.lu_factor, + signature=signature, + ) + return _lu_factor_gufunc(a) + + +def norm(x, ord=None, axis=None, keepdims=False): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if "int" in dtype or dtype == "bool": + dtype = dtypes.result_type(x.dtype, "float32") + return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims).astype( + dtype + ) + + +def qr(x, mode="reduced"): + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. 
" + f"Received: mode={mode}" + ) + return np.linalg.qr(x, mode=mode) + + +def solve(a, b): + return np.linalg.solve(a, b) + + +def solve_triangular(a, b, lower=False): + if a.ndim == 2: + return sl.solve_triangular(a, b, lower=lower) + + _vectorized_solve_triangular = np.vectorize( + lambda a, b: sl.solve_triangular(a, b, lower=lower), + signature="(n,n),(n,m)->(n,m)", + ) + if b.ndim == a.ndim - 1: + b = np.expand_dims(b, axis=-1) + return _vectorized_solve_triangular(a, b).squeeze(axis=-1) + return _vectorized_solve_triangular(a, b) + + +def svd(x, full_matrices=True, compute_uv=True): + return np.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) + + +def lstsq(a, b, rcond=None): + a = convert_to_tensor(a) + b = convert_to_tensor(b) + return np.linalg.lstsq(a, b, rcond=rcond)[0] diff --git a/keras/src/backend/openvino/math.py b/keras/src/backend/openvino/math.py new file mode 100644 index 00000000000..81e5d71639d --- /dev/null +++ b/keras/src/backend/openvino/math.py @@ -0,0 +1,329 @@ +import numpy as np + +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.jax.math import fft as jax_fft +from keras.src.backend.jax.math import fft2 as jax_fft2 +from keras.src.backend.openvino.core import convert_to_tensor +from keras.src.utils.module_utils import scipy + + +def segment_sum(data, segment_ids, num_segments=None, sorted=False): + if num_segments is None: + num_segments = np.amax(segment_ids) + 1 + + valid_indices = segment_ids >= 0 # Ignore segment_ids that are -1 + valid_data = data[valid_indices] + valid_segment_ids = segment_ids[valid_indices] + + data_shape = list(valid_data.shape) + data_shape[0] = ( + num_segments # Replace first dimension (which corresponds to segments) + ) + + if sorted: + result = np.zeros(data_shape, dtype=valid_data.dtype) + np.add.at(result, valid_segment_ids, valid_data) + else: + sort_indices = np.argsort(valid_segment_ids) + sorted_segment_ids = valid_segment_ids[sort_indices] + sorted_data = valid_data[sort_indices] + + result = np.zeros(data_shape, dtype=valid_data.dtype) + np.add.at(result, sorted_segment_ids, sorted_data) + + return result + + +def segment_max(data, segment_ids, num_segments=None, sorted=False): + if num_segments is None: + num_segments = np.amax(segment_ids) + 1 + + valid_indices = segment_ids >= 0 # Ignore segment_ids that are -1 + valid_data = data[valid_indices] + valid_segment_ids = segment_ids[valid_indices] + + data_shape = list(valid_data.shape) + data_shape[0] = ( + num_segments # Replace first dimension (which corresponds to segments) + ) + + if sorted: + result = np.zeros(data_shape, dtype=valid_data.dtype) + np.maximum.at(result, valid_segment_ids, valid_data) + else: + sort_indices = np.argsort(valid_segment_ids) + sorted_segment_ids = valid_segment_ids[sort_indices] + sorted_data = valid_data[sort_indices] + + result = np.zeros(data_shape, dtype=valid_data.dtype) + np.maximum.at(result, sorted_segment_ids, sorted_data) + + return result + + +def top_k(x, k, sorted=False): + sorted_indices = np.argsort(x, axis=-1)[..., ::-1] + sorted_values = np.sort(x, axis=-1)[..., ::-1] + + if sorted: + # Take the k largest values. + top_k_values = sorted_values[..., :k] + top_k_indices = sorted_indices[..., :k] + else: + # Partition the array such that all values larger than the k-th + # largest value are to the right of it. 
+ top_k_values = np.partition(x, -k, axis=-1)[..., -k:] + top_k_indices = np.argpartition(x, -k, axis=-1)[..., -k:] + + # Get the indices in sorted order. + idx = np.argsort(-top_k_values, axis=-1) + + # Get the top k values and their indices. + top_k_values = np.take_along_axis(top_k_values, idx, axis=-1) + top_k_indices = np.take_along_axis(top_k_indices, idx, axis=-1) + + return top_k_values, top_k_indices + + +def in_top_k(targets, predictions, k): + targets = targets[:, None] + topk_values = top_k(predictions, k)[0] + targets_values = np.take_along_axis(predictions, targets, axis=-1) + mask = targets_values >= topk_values + return np.any(mask, axis=-1) + + +def logsumexp(x, axis=None, keepdims=False): + max_x = np.max(x, axis=axis, keepdims=True) + result = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + max_x + return np.squeeze(result) if not keepdims else result + + +def qr(x, mode="reduced"): + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + return np.linalg.qr(x, mode=mode) + + +def extract_sequences(x, sequence_length, sequence_stride): + *batch_shape, _ = x.shape + batch_shape = list(batch_shape) + shape = x.shape[:-1] + ( + (x.shape[-1] - (sequence_length - sequence_stride)) // sequence_stride, + sequence_length, + ) + strides = x.strides[:-1] + ( + sequence_stride * x.strides[-1], + x.strides[-1], + ) + x = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + return np.reshape(x, (*batch_shape, *x.shape[-2:])) + + +def _get_complex_tensor_from_tuple(x): + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + f"Received: x={x}" + ) + # `convert_to_tensor` does not support passing complex tensors. We separate + # the input out into real and imaginary and convert them separately. + real, imag = x + # Check shapes. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + "Both the real and imaginary parts should have the same shape. " + f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" + ) + # Ensure dtype is float. + if not np.issubdtype(real.dtype, np.floating) or not np.issubdtype( + imag.dtype, np.floating + ): + raise ValueError( + "At least one tensor in input `x` is not of type float." + f"Received: x={x}." + ) + complex_input = real + 1j * imag + return complex_input + + +def fft(x): + real, imag = jax_fft(x) + return np.array(real), np.array(imag) + + +def fft2(x): + real, imag = jax_fft2(x) + return np.array(real), np.array(imag) + + +def rfft(x, fft_length=None): + complex_output = np.fft.rfft(x, n=fft_length, axis=-1, norm="backward") + # numpy always outputs complex128, so we need to recast the dtype + return ( + np.real(complex_output).astype(x.dtype), + np.imag(complex_output).astype(x.dtype), + ) + + +def irfft(x, fft_length=None): + complex_input = _get_complex_tensor_from_tuple(x) + # numpy always outputs float64, so we need to recast the dtype + return np.fft.irfft( + complex_input, n=fft_length, axis=-1, norm="backward" + ).astype(x[0].dtype) + + +def stft( + x, sequence_length, sequence_stride, fft_length, window="hann", center=True +): + if standardize_dtype(x.dtype) not in {"float32", "float64"}: + raise TypeError( + "Invalid input type. Expected `float32` or `float64`. 
" + f"Received: input type={x.dtype}" + ) + if fft_length < sequence_length: + raise ValueError( + "`fft_length` must equal or larger than `sequence_length`. " + f"Received: sequence_length={sequence_length}, " + f"fft_length={fft_length}" + ) + if isinstance(window, str): + if window not in {"hann", "hamming"}: + raise ValueError( + "If a string is passed to `window`, it must be one of " + f'`"hann"`, `"hamming"`. Received: window={window}' + ) + x = convert_to_tensor(x) + ori_dtype = x.dtype + + if center: + pad_width = [(0, 0) for _ in range(len(x.shape))] + pad_width[-1] = (fft_length // 2, fft_length // 2) + x = np.pad(x, pad_width, mode="reflect") + + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + + if window is not None: + if isinstance(window, str): + win = convert_to_tensor( + scipy.signal.get_window(window, sequence_length), dtype=x.dtype + ) + else: + win = convert_to_tensor(window, dtype=x.dtype) + if len(win.shape) != 1 or win.shape[-1] != sequence_length: + raise ValueError( + "The shape of `window` must be equal to [sequence_length]." + f"Received: window shape={win.shape}" + ) + win = np.pad(win, [[l_pad, r_pad]]) + else: + win = np.ones((sequence_length + l_pad + r_pad), dtype=x.dtype) + + x = scipy.signal.stft( + x, + fs=1.0, + window=win, + nperseg=(sequence_length + l_pad + r_pad), + noverlap=(sequence_length + l_pad + r_pad - sequence_stride), + nfft=fft_length, + boundary=None, + padded=False, + )[-1] + + # scale and swap to (..., num_sequences, fft_bins) + x = x / np.sqrt(1.0 / win.sum() ** 2) + x = np.swapaxes(x, -2, -1) + return np.real(x).astype(ori_dtype), np.imag(x).astype(ori_dtype) + + +def istft( + x, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, +): + x = _get_complex_tensor_from_tuple(x) + dtype = np.real(x).dtype + + expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + + if window is not None: + if isinstance(window, str): + win = convert_to_tensor( + scipy.signal.get_window(window, sequence_length), dtype=dtype + ) + else: + win = convert_to_tensor(window, dtype=dtype) + if len(win.shape) != 1 or win.shape[-1] != sequence_length: + raise ValueError( + "The shape of `window` must be equal to [sequence_length]." 
+ f"Received: window shape={win.shape}" + ) + win = np.pad(win, [[l_pad, r_pad]]) + else: + win = np.ones((sequence_length + l_pad + r_pad), dtype=dtype) + + x = scipy.signal.istft( + x, + fs=1.0, + window=win, + nperseg=(sequence_length + l_pad + r_pad), + noverlap=(sequence_length + l_pad + r_pad - sequence_stride), + nfft=fft_length, + boundary=False, + time_axis=-2, + freq_axis=-1, + )[-1] + + # scale + x = x / win.sum() if window is not None else x / sequence_stride + + start = 0 if center is False else fft_length // 2 + if length is not None: + end = start + length + elif center is True: + end = -(fft_length // 2) + else: + end = expected_output_len + return x[..., start:end] + + +def rsqrt(x): + return 1.0 / np.sqrt(x) + + +def erf(x): + return np.array(scipy.special.erf(x)) + + +def erfinv(x): + return np.array(scipy.special.erfinv(x)) + + +def solve(a, b): + a = convert_to_tensor(a) + b = convert_to_tensor(b) + return np.linalg.solve(a, b) + + +def norm(x, ord=None, axis=None, keepdims=False): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if "int" in dtype or dtype == "bool": + dtype = dtypes.result_type(x.dtype, "float32") + return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims).astype( + dtype + ) diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py new file mode 100644 index 00000000000..0270bc0492e --- /dev/null +++ b/keras/src/backend/openvino/nn.py @@ -0,0 +1,960 @@ +import jax +import numpy as np +from jax import lax + +from keras.src import backend +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_padding_args_for_jax, +) +from keras.src.backend.openvino.core import cast +from keras.src.backend.openvino.core import convert_to_tensor +from keras.src.backend.openvino.core import is_tensor + + +def relu(x): + from openvino.runtime.opset14 import relu + return relu(x) + + +def relu6(x): + raise NotImplementedError( + "`relu6` is not supported with openvino backend" + ) + + +def sigmoid(x): + from openvino.runtime.opset14 import sigmoid + return sigmoid(x) + + +def tanh(x): + from openvino.runtime.opset14 import tanh + return tanh(x) + + +def softplus(x): + from openvino.runtime.opset14 import softplus + return softplus(x) + + +def softsign(x): + from openvino.runtime.opset14 import softsign + return softsign(x) + + +def silu(x): + from openvino.runtime.opset14 import sigmoid, multiply + return multiply(x, sigmoid(x)) + + +def log_sigmoid(x): + from openvino.runtime.opset14 import softplus, negative + return negative(softplus(negative(x))) + + +def leaky_relu(x, negative_slope=0.2): + x = convert_to_tensor(x) + return np.maximum(x, np.array(negative_slope, x.dtype) * x) + + +def hard_sigmoid(x): + from openvino.runtime.opset14 import hard_sigmoid + alpha = 1 / np.array(6.0, dtype=np.float32) + beta = np.array(0.5, dtype=np.float32) + return hard_sigmoid(x, alpha, beta) + + +def hard_silu(x): + from openvino.runtime.opset14 import multiply + return multiply(x, hard_sigmoid(x)) + + +def elu(x, alpha=1.0): + from openvino.runtime.opset14 import elu + return elu(x, alpha) + + +def selu( + x, + alpha=1.6732632423543772848170429916717, + scale=1.0507009873554804934193349852946, +): + from openvino.runtime.opset14 import selu + return selu(x, alpha, scale) + + +def gelu(x, approximate=True): + from openvino.runtime.opset14 import gelu + approximate_mode = "erf" + if approximate: + approximate_mode = "tanh" + return gelu(x, approximate_mode) + + +def softmax(x, axis=None): + from 
openvino.runtime.opset14 import softmax + return softmax(x, axis) + + +def log_softmax(x, axis=None): + from openvino.runtime.opset14 import log_softmax + return log_softmax(x, axis) + + +def _convert_to_spatial_operand( + x, + num_spatial_dims, + data_format="channels_last", + include_batch_and_channels=True, +): + # Helper function that converts an operand to a spatial operand. + x = (x,) * num_spatial_dims if isinstance(x, int) else x + if not include_batch_and_channels: + return x + if data_format == "channels_last": + x = (1,) + x + (1,) + else: + x = (1,) + (1,) + x + return x + + +def _pool( + inputs, + initial_value, + reduce_fn, + pool_size, + strides=None, + padding="valid", +): + """Helper function to define pooling functions. + + Args: + inputs: input data of shape `N+2`. + initial_value: the initial value for the reduction. + reduce_fn: a reduce function of the form `(T, T) -> T`. + pool_size: a sequence of `N` integers, representing the window size to + reduce over. + strides: a sequence of `N` integers, representing the inter-window + strides (default: `(1, ..., 1)`). + padding: either the string `same` or `valid`. + + Returns: + The output of the reduction for each window slice. + """ + if padding not in ("same", "valid"): + raise ValueError( + f"Invalid padding '{padding}', must be 'same' or 'valid'." + ) + padding = padding.upper() + return np.array( + lax.reduce_window( + inputs, + initial_value, + reduce_fn, + pool_size, + strides, + padding, + ) + ) + + +def max_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + pool_size = _convert_to_spatial_operand( + pool_size, num_spatial_dims, data_format + ) + strides = pool_size if strides is None else strides + strides = _convert_to_spatial_operand( + strides, num_spatial_dims, data_format + ) + return _pool(inputs, -np.inf, lax.max, pool_size, strides, padding) + + +def average_pool( + inputs, + pool_size, + strides, + padding, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + pool_size = _convert_to_spatial_operand( + pool_size, num_spatial_dims, data_format + ) + strides = pool_size if strides is None else strides + strides = _convert_to_spatial_operand( + strides, num_spatial_dims, data_format + ) + + pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding) + if padding == "valid": + # Avoid the extra reduce_window. + return pooled / np.prod(pool_size) + else: + # Count the number of valid entries at each input point, then use that + # for computing average. Assumes that any two arrays of same shape will + # be padded the same. Avoid broadcasting on axis where pooling is + # skipped. 
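+        # e.g. hypothetical inputs of shape (2, 9, 9, 3) with pool_size
+        # (1, 2, 2, 1) give shape == [1, 9, 9, 1]: the counts only have to
+        # vary along the pooled spatial axes and broadcast elsewhere.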
+        shape = [
+            (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)
+        ]
+        window_counts = _pool(
+            np.ones(shape, inputs.dtype),
+            0.0,
+            lax.add,
+            pool_size,
+            strides,
+            padding,
+        )
+        return pooled / window_counts
+
+
+def _convert_to_lax_conv_dimension_numbers(
+    num_spatial_dims,
+    data_format="channels_last",
+    transpose=False,
+):
+    """Create a `lax.ConvDimensionNumbers` for the given inputs."""
+    num_dims = num_spatial_dims + 2
+
+    if data_format == "channels_last":
+        spatial_dims = tuple(range(1, num_dims - 1))
+        inputs_dn = (0, num_dims - 1) + spatial_dims
+    else:
+        spatial_dims = tuple(range(2, num_dims))
+        inputs_dn = (0, 1) + spatial_dims
+
+    if transpose:
+        kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
+    else:
+        kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
+
+    return lax.ConvDimensionNumbers(
+        lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn
+    )
+
+
+def conv(
+    inputs,
+    kernel,
+    strides=1,
+    padding="valid",
+    data_format=None,
+    dilation_rate=1,
+):
+    data_format = backend.standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
+        num_spatial_dims,
+        data_format,
+        transpose=False,
+    )
+    strides = _convert_to_spatial_operand(
+        strides,
+        num_spatial_dims,
+        data_format,
+        include_batch_and_channels=False,
+    )
+    dilation_rate = _convert_to_spatial_operand(
+        dilation_rate,
+        num_spatial_dims,
+        data_format,
+        include_batch_and_channels=False,
+    )
+    if data_format == "channels_last":
+        channels = inputs.shape[-1]
+    else:
+        channels = inputs.shape[1]
+    kernel_in_channels = kernel.shape[-2]
+    if channels % kernel_in_channels > 0:
+        raise ValueError(
+            "The number of input channels must be evenly divisible by "
+            f"kernel's in_channels. Received input channels {channels} and "
+            f"kernel in_channels {kernel_in_channels}."
+        )
+    feature_group_count = channels // kernel_in_channels
+    return np.array(
+        jax.lax.conv_general_dilated(
+            inputs,
+            kernel if is_tensor(kernel) else kernel.numpy(),
+            strides,
+            padding,
+            rhs_dilation=dilation_rate,
+            dimension_numbers=dimension_numbers,
+            feature_group_count=feature_group_count,
+        )
+    )
+
+
+def depthwise_conv(
+    inputs,
+    kernel,
+    strides=1,
+    padding="valid",
+    data_format=None,
+    dilation_rate=1,
+):
+    data_format = backend.standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
+        num_spatial_dims,
+        data_format,
+        transpose=False,
+    )
+    strides = _convert_to_spatial_operand(
+        strides,
+        num_spatial_dims,
+        data_format,
+        include_batch_and_channels=False,
+    )
+    dilation_rate = _convert_to_spatial_operand(
+        dilation_rate,
+        num_spatial_dims,
+        data_format,
+        include_batch_and_channels=False,
+    )
+    feature_group_count = (
+        inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
+    )
+    kernel = np.reshape(
+        kernel if is_tensor(kernel) else kernel.numpy(),
+        kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
+    )
+    return np.array(
+        jax.lax.conv_general_dilated(
+            inputs,
+            kernel,
+            strides,
+            padding,
+            rhs_dilation=dilation_rate,
+            dimension_numbers=dimension_numbers,
+            feature_group_count=feature_group_count,
+        )
+    )
+
+
+def separable_conv(
+    inputs,
+    depthwise_kernel,
+    pointwise_kernel,
+    strides=1,
+    padding="valid",
+    data_format=None,
+    dilation_rate=1,
+):
+    data_format = backend.standardize_data_format(data_format)
+    depthwise_conv_output = depthwise_conv(
+        inputs,
+        depthwise_kernel,
+        strides,
+        padding,
+        data_format,
+        dilation_rate,
+    )
+    return conv(
+        depthwise_conv_output,
+        pointwise_kernel,
+        strides=1,
+        padding="valid",
+        data_format=data_format,
+        dilation_rate=dilation_rate,
+    )
+
+
+def conv_transpose(
+    inputs,
+    kernel,
+    strides=1,
+    padding="valid",
+    output_padding=None,
+    data_format=None,
+    dilation_rate=1,
+):
+    data_format = backend.standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+    padding_values = compute_conv_transpose_padding_args_for_jax(
+        input_shape=inputs.shape,
+        kernel_shape=kernel.shape,
+        strides=strides,
+        padding=padding,
+        output_padding=output_padding,
+        dilation_rate=dilation_rate,
+    )
+    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
+        num_spatial_dims,
+        data_format,
+        transpose=False,
+    )
+    strides = _convert_to_spatial_operand(
+        strides,
+        num_spatial_dims,
+        data_format,
+        include_batch_and_channels=False,
+    )
+    dilation_rate = _convert_to_spatial_operand(
+        dilation_rate,
+        num_spatial_dims,
+        data_format,
+        include_batch_and_channels=False,
+    )
+
+    return np.array(
+        jax.lax.conv_transpose(
+            inputs,
+            kernel if is_tensor(kernel) else kernel.numpy(),
+            strides,
+            padding=padding_values,
+            rhs_dilation=dilation_rate,
+            dimension_numbers=dimension_numbers,
+            transpose_kernel=True,
+        )
+    )
+
+
+def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
+    if sparse:
+        raise ValueError(
+            "Unsupported value `sparse=True` with openvino backend"
+        )
+    x = convert_to_tensor(x)
+    input_shape = x.shape
+
+    # Shrink the last dimension if the shape is (..., 1).
+    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
+        input_shape = tuple(input_shape[:-1])
+
+    x = x.reshape(-1)
+    if not num_classes:
+        num_classes = np.max(x) + 1
+
+    batch_size = x.shape[0]
+    categorical = np.zeros((batch_size, num_classes), dtype=dtype)
+    valid_indices = x >= 0
+    categorical[np.arange(batch_size)[valid_indices], x[valid_indices]] = 1
+
+    # First, reshape the array with the extra dimension at the end
+    output_shape = input_shape + (num_classes,)
+    categorical = np.reshape(categorical, output_shape)
+
+    # Then, move this new dimension to the right place (according to axis)
+    if axis != -1:
+        categorical = np.moveaxis(categorical, -1, axis)
+
+    return categorical
+
+
+def multi_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
+    if sparse:
+        raise ValueError(
+            "Unsupported value `sparse=True` with openvino backend"
+        )
+    x = convert_to_tensor(x)
+    reduction_axis = 1 if len(x.shape) > 1 else 0
+    outputs = np.max(
+        one_hot(cast(x, "int32"), num_classes, axis=axis, dtype=dtype),
+        axis=reduction_axis,
+    )
+    return outputs
+
+
+def categorical_crossentropy(target, output, from_logits=False, axis=-1):
+    target = np.array(target)
+    output = np.array(output)
+
+    if target.shape != output.shape:
+        raise ValueError(
+            "Arguments `target` and `output` must have the same shape. "
+            "Received: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+    if len(target.shape) < 1:
+        raise ValueError(
+            "Arguments `target` and `output` must be at least rank 1. "
+            "Received: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+
+    if from_logits:
+        log_prob = log_softmax(output, axis=axis)
+    else:
+        output = output / np.sum(output, axis, keepdims=True)
+        output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
+        log_prob = np.log(output)
+    return -np.sum(target * log_prob, axis=axis)
+
+
+def sparse_categorical_crossentropy(
+    target, output, from_logits=False, axis=-1
+):
+    target = np.array(target, dtype="int32")
+    output = np.array(output)
+    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
+        target = np.squeeze(target, axis=-1)
+
+    if len(output.shape) < 1:
+        raise ValueError(
+            "Argument `output` must be at least rank 1. "
+            "Received: "
+            f"output.shape={output.shape}"
+        )
+    if target.shape != output.shape[:-1]:
+        raise ValueError(
+            "Arguments `target` and `output` must have the same shape "
+            "up until the last dimension: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+    if from_logits:
+        log_prob = log_softmax(output, axis=axis)
+    else:
+        output = output / np.sum(output, axis, keepdims=True)
+        output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
+        log_prob = np.log(output)
+    target = one_hot(target, output.shape[axis], axis=axis)
+    return -np.sum(target * log_prob, axis=axis)
+
+
+def binary_crossentropy(target, output, from_logits=False):
+    target = np.array(target)
+    output = np.array(output)
+
+    if target.shape != output.shape:
+        raise ValueError(
+            "Arguments `target` and `output` must have the same shape. "
+            "Received: "
+            f"target.shape={target.shape}, output.shape={output.shape}"
+        )
+
+    if from_logits:
+        output = sigmoid(output)
+
+    output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
+    bce = target * np.log(output)
+    bce += (1.0 - target) * np.log(1.0 - output)
+    return -bce
+
+
+def moments(x, axes, keepdims=False, synchronized=False):
+    if synchronized:
+        raise NotImplementedError(
+            "Argument synchronized=True is not supported "
+            "with openvino backend."
+ ) + axes = tuple(axes) if isinstance(axes, list) else axes + # The dynamic range of float16 is too limited for statistics. As a + # workaround, we simply perform the operations on float32 and convert back + # to float16 + need_cast = False + ori_dtype = backend.standardize_dtype(x.dtype) + if ori_dtype == "float16": + need_cast = True + x = cast(x, "float32") + + mean = np.mean(x, axes, keepdims=True) + + # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, It is faster + # but less numerically stable. + variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean) + + if not keepdims: + mean = np.squeeze(mean, axes) + variance = np.squeeze(variance, axes) + if need_cast: + # avoid overflow and underflow when casting from float16 to float32 + mean = np.clip(mean, np.finfo(np.float16).min, np.finfo(np.float16).max) + variance = np.clip( + variance, np.finfo(np.float16).min, np.finfo(np.float16).max + ) + mean = cast(mean, ori_dtype) + variance = cast(variance, ori_dtype) + return mean, variance + + +def batch_normalization( + x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 +): + shape = [1] * len(x.shape) + shape[axis] = mean.shape[0] + mean = np.reshape(mean, shape) + variance = np.reshape(variance, shape) + + inv = 1.0 / np.sqrt(variance + epsilon) + if scale is not None: + scale = np.reshape(scale, shape) + inv = inv * scale + + res = -mean * inv + if offset is not None: + offset = np.reshape(offset, shape) + res = res + offset + + return x * inv + res + + +def ctc_loss(target, output, target_length, output_length, mask_index=0): + # Ref: https://github.com/google-deepmind/optax + # optax.ctc_loss_with_forward_probs + target = convert_to_tensor(target, dtype="int32") + output = convert_to_tensor(output) + target_length = convert_to_tensor(target_length, "int32") + output_length = convert_to_tensor(output_length, "int32") + batch_size, max_input_length, num_classes = output.shape + batch_size, max_label_length = target.shape + log_epsilon = -1e5 + + # Ensure that the dtype promotion behavior matchs that of `tf.nn.ctc_loss` + dtype = backend.result_type(output.dtype, "float32") + output = output.astype(dtype) + + def _lengths_to_paddings(lengths, max_length): + indices = np.arange(max_length).reshape( + (1,) * lengths.ndim + (max_length,) + ) + lengths = np.expand_dims(lengths, axis=-1) + elem_valid = indices < lengths + return np.logical_not(elem_valid) + + target_paddings = _lengths_to_paddings(target_length, max_label_length) + output_paddings = _lengths_to_paddings(output_length, max_input_length) + target_paddings = target_paddings.astype(output.dtype) + output_paddings = output_paddings.astype(output.dtype) + + logprobs = log_softmax(output, axis=-1) + label_lengths = max_label_length - np.sum(target_paddings, axis=1).astype( + np.int32 + ) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. 
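+    # e.g. a target row [a, a, b, b, b] yields the repeat row [1, 0, 1, 1],
+    # which np.pad extends with a trailing 0 to length max_label_length.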
+ repeat = (target[:, :-1] == target[:, 1:]).astype(np.float32) + repeat = np.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, mask_index: mask_index + 1] # [B, T, 1] + logprobs_phi = np.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + _one_hot = one_hot(target, num_classes=num_classes) # [B, N, K] + logprobs_emit = np.einsum("btk,bnk->btn", logprobs, _one_hot) + logprobs_emit = np.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + # [B, N] + logalpha_phi_init = ( + np.ones((batch_size, max_label_length + 1), dtype=output.dtype) + * log_epsilon + ) + logalpha_phi_init[:, 0] = 0.0 + logalpha_emit_init = ( + np.ones((batch_size, max_label_length), dtype=output.dtype) + * log_epsilon + ) + + def update_phi_score(phi, added_score): + # Update `phi[:, 1:]`` with adding `added_score` in log space. + return np.concatenate( + [phi[:, :1], np.logaddexp(phi[:, 1:], added_score)], axis=-1 + ) + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = np.logaddexp( + prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit + ) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = update_phi_score( + next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat) + ) + + pad = pad.reshape((batch_size, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + def np_scan(f, init, xs): + carry = init + ys = [] + for x in zip(*xs): + carry, y = f(carry, x) + ys.append(y) + result = [] + for i in range(len(ys[0])): + result.append(np.stack([y[i] for y in ys])) + return carry, result + + xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = np_scan( + loop_body, (logalpha_phi_init, logalpha_emit_init), xs + ) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1]) + logalpha_phi[-1] = logalpha_phi_last + + # extract per_seq_loss + # [B, N+1] + _one_hot = one_hot(label_lengths, num_classes=max_label_length + 1) + per_seq_loss = -np.einsum("bn,bn->b", logalpha_phi_last, _one_hot) + return per_seq_loss + + +def _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=True, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") + batch_size, max_length, num_classes = inputs.shape + + if mask_index is None: + mask_index = num_classes - 1 + + indices = np.argmax(inputs, axis=-1).astype("int32") + scores = np.max(inputs, axis=-1) + + seqlen_mask = np.arange(max_length)[None, :] + seqlen_mask = seqlen_mask >= sequence_lengths[:, None] + + indices = np.where(seqlen_mask, mask_index, indices) + scores = np.where(seqlen_mask, 0.0, scores) + + if merge_repeated: + repeat_mask = indices[:, 1:] == indices[:, :-1] + repeat_mask = np.pad(repeat_mask, ((0, 0), (1, 0))) + indices = np.where(repeat_mask, mask_index, indices) + + # We set to -1 for blank labels + invalid_mask = indices == mask_index + indices = np.where(invalid_mask, -1, indices) + + # We rearrange the indices by moving `mask_index` to the end of the array + 
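+    # e.g. with max_length=4 and indices [c1, -1, c2, -1] for hypothetical
+    # labels c1 and c2, order becomes [0, 4, 2, 4]; argsort then gathers the
+    # valid labels first, giving [c1, c2, -1, -1].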
order = np.expand_dims(np.arange(max_length), axis=0) # [1, N] + order = np.tile(order, (batch_size, 1)) # [B, N] + order = np.where(invalid_mask, max_length, order) + order = np.argsort(order, axis=-1) + indices = np.take_along_axis(indices, order, axis=-1) + + scores = -np.sum(scores, axis=1)[:, None] + indices = np.expand_dims(indices, axis=0) + return indices, scores + + +def _ctc_beam_search_decode( + inputs, + sequence_lengths, + beam_width=100, + top_paths=1, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths) + + batch_size, max_seq_len, num_classes = inputs.shape + inputs = log_softmax(inputs, axis=-1) + seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_lengths[:, None] + + if mask_index is None: + mask_index = num_classes - 1 + + # This is a workaround for the fact that np.argsort does not support + # the order parameter which is used to break ties when scores are equal. + # For compatibility with the tensorflow implementation, we flip the inputs + # and the mask_index, and then flip the classes back to the correct indices + inputs = np.flip(inputs, axis=2) + mask_index = num_classes - mask_index - 1 + + _pad = -1 + + init_paths = np.full( + (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=np.int32 + ) + + num_init_paths = np.min(np.array([num_classes, beam_width])) + max_classes = np.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:] + init_classes = np.where(max_classes == mask_index, _pad, max_classes) + init_paths[:, :num_init_paths, 0] = init_classes + + init_scores = np.full( + (batch_size, 2 * beam_width), -np.inf, dtype=inputs.dtype + ) + init_scores[:, :num_init_paths] = np.take_along_axis( + inputs[:, 0], max_classes, axis=1 + ) + init_masked = init_paths[:, :, 0] == _pad + + def _extend_paths(paths, scores, masked, x): + paths = np.repeat(paths, num_classes, axis=0) + scores = np.repeat(scores, num_classes) + masked = np.repeat(masked, num_classes) + + path_tail_index = np.argmax(paths == _pad, axis=1) + paths_arange = np.arange(2 * beam_width * num_classes) + path_tails = paths[paths_arange, path_tail_index - 1] + path_tails = np.where(path_tail_index == 0, _pad, path_tails) + + classes = np.arange(num_classes) + classes[mask_index] = _pad + classes = np.tile(classes, 2 * beam_width) + + prev_masked = masked + masked = classes == _pad + + masked_repeat = ~prev_masked & (path_tails == classes) + classes = np.where(masked_repeat, _pad, classes) + paths[paths_arange, path_tail_index] = classes + + x = np.tile(x, 2 * beam_width) + scores = scores + x + + return paths, scores, masked + + def _merge_scores(unique_inverse, scores): + scores_max = np.max(scores) + scores_exp = np.exp(scores - scores_max) + scores = np.zeros_like(scores) + for i, u in enumerate(unique_inverse): + scores[u] += scores_exp[i] + scores = np.log(scores) + scores_max + return scores + + def _prune_paths(paths, scores, masked): + paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0) + pad_size = (2 * num_classes * beam_width) - len(paths) + if pad_size > 0: + paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad) + paths = paths[: 2 * num_classes * beam_width] + if len(unique_inverse.shape) >= 2: + unique_inverse = np.squeeze(unique_inverse, axis=1) + + emit_scores = np.where(masked, -np.inf, scores) + mask_scores = np.where(masked, scores, -np.inf) + + emit_scores = _merge_scores(unique_inverse, emit_scores) + mask_scores = _merge_scores(unique_inverse, mask_scores) + + total_scores = 
np.logaddexp(emit_scores, mask_scores) + top_indices = np.argsort(total_scores, kind="stable")[-beam_width:] + + paths = paths[top_indices] + emit_scores = emit_scores[top_indices] + mask_scores = mask_scores[top_indices] + + paths = np.tile(paths, (2, 1)) + scores = np.concatenate([emit_scores, mask_scores]) + masked = np.concatenate( + [np.zeros(beam_width, bool), np.ones(beam_width, bool)] + ) + + return paths, scores, masked + + def _decode_step(paths, scores, masked, x): + paths, scores, masked = _extend_paths(paths, scores, masked, x) + paths, scores, masked = _prune_paths(paths, scores, masked) + return paths, scores, masked + + def _step(prev, x): + paths, scores, masked = prev + x, seqlen_mask = x + if not seqlen_mask: + paths, scores, masked = _decode_step(paths, scores, masked, x) + return (paths, scores, masked), None + + def _decode_batch( + init_paths, init_scores, init_masked, inputs, seqlen_mask + ): + def np_scan_only_carry(f, init, xs): + carry = init + for x in zip(*xs): + carry, y = f(carry, x) + return carry, None + + (paths, scores, masked), _ = np_scan_only_carry( + _step, + (init_paths, init_scores, init_masked), + (inputs[1:], seqlen_mask[1:]), + ) + + paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0) + pad_size = (2 * num_classes * beam_width) - len(paths) + if pad_size > 0: + paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad) + paths = paths[: 2 * num_classes * beam_width] + if len(unique_inverse.shape) >= 2: + unique_inverse = np.squeeze(unique_inverse, axis=1) + scores = _merge_scores(unique_inverse, scores) + + top_indices = np.argsort(scores)[-top_paths:][::-1] + paths = paths[top_indices] + scores = scores[top_indices] + + return paths, scores + + results = [ + _decode_batch(p, s, m, i, sm) + for p, s, m, i, sm in zip( + init_paths, init_scores, init_masked, inputs, seqlen_mask + ) + ] + paths = np.stack([r[0] for r in results]) + scores = np.stack([r[1] for r in results]) + + # convert classes back to the correct indices + paths = np.where(paths == _pad, _pad, num_classes - paths - 1) + paths = np.transpose(paths, [1, 0, 2]) + return paths, scores + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + inputs = convert_to_tensor(inputs) + dtype = backend.result_type(inputs.dtype, "float32") + inputs = cast(inputs, dtype) + + if strategy == "greedy": + return _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ) + elif strategy == "beam_search": + return _ctc_beam_search_decode( + inputs, + sequence_lengths, + beam_width=beam_width, + top_paths=top_paths, + mask_index=mask_index, + ) + else: + raise ValueError( + f"Invalid strategy {strategy}. Supported values are " + "'greedy' and 'beam_search'." + ) + + +def psnr(x1, x2, max_val): + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. 
" + ) + + max_val = convert_to_tensor(max_val, dtype=x2.dtype) + mse = np.mean(np.square(x1 - x2)) + psnr = 20 * np.log10(max_val) - 10 * np.log10(mse) + return psnr diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py new file mode 100644 index 00000000000..8bad937dc35 --- /dev/null +++ b/keras/src/backend/openvino/numpy.py @@ -0,0 +1,1087 @@ +import numpy as np + +from keras.src import tree +from keras.src.backend import config +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.common.backend_utils import standardize_axis_for_numpy +from keras.src.backend.openvino.core import convert_to_tensor, ov_to_keras_type, OPENVINO_DTYPES + + +def _align_operand_types(x1, x2, op_name): + from openvino.runtime.opset14 import convert + x1_type = x1.element_type + x2_type = x2.element_type + if x1_type.is_dynamic() or x2_type.is_dynamic(): + raise ValueError( + f"'{op_name}' operation is not supported for dynamic operand type with openvino backend" + ) + x1_type = ov_to_keras_type(x1_type) + x2_type = ov_to_keras_type(x2_type) + result_type = dtypes.result_type(x1_type, x2_type) + result_type = OPENVINO_DTYPES[result_type] + if x1_type != result_type: + x1 = convert(x1, result_type) + if x2_type != result_type: + x2 = convert(x2, result_type) + return x1, x2 + + +def add(x1, x2): + from openvino.runtime.opset14 import add + x1, x2 = _align_operand_types(x1, x2, "add()") + return add(x1, x2) + + +def einsum(subscripts, *operands, **kwargs): + operands = tree.map_structure(convert_to_tensor, operands) + dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands)) + # When operands are of int8, we cast the result to int32 to align with + # the behavior of jax. 
+ if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8": + compute_dtype = "int32" # prevent overflow + result_dtype = "int32" + else: + result_dtype = dtypes.result_type(*dtypes_to_resolve) + compute_dtype = result_dtype + # TODO: np.einsum doesn't support bfloat16 + if compute_dtype == "bfloat16": + compute_dtype = "float32" + operands = tree.map_structure(lambda x: x.astype(compute_dtype), operands) + return np.einsum(subscripts, *operands, **kwargs).astype(result_dtype) + + +def subtract(x1, x2): + from openvino.runtime.opset14 import subtract + x1, x2 = _align_operand_types(x1, x2, "subtract()") + return subtract(x1, x2) + + +def matmul(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + # When both x1 and x2 are of int8, we cast the outputs to int32 to align + # with jax + x1_dtype = standardize_dtype(x1.dtype) + x2_dtype = standardize_dtype(x2.dtype) + if x1_dtype == "int8" and x2_dtype == "int8": + dtype = "int32" + else: + dtype = dtypes.result_type(x1.dtype, x2.dtype) + return np.matmul(x1, x2).astype(dtype) + + +def multiply(x1, x2): + from openvino.runtime.opset14 import multiply + x1, x2 = _align_operand_types(x1, x2, "multiply()") + return multiply(x1, x2) + + +def mean(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + result_dtype = dtypes.result_type(x.dtype, "float32") + else: + result_dtype = ori_dtype + return np.mean(x, axis=axis, keepdims=keepdims).astype(result_dtype) + + +def max(x, axis=None, keepdims=False, initial=None): + axis = standardize_axis_for_numpy(axis) + return np.max(x, axis=axis, keepdims=keepdims, initial=initial) + + +def ones(shape, dtype=None): + dtype = dtype or config.floatx() + return np.ones(shape, dtype=dtype) + + +def zeros(shape, dtype=None): + dtype = dtype or config.floatx() + return np.zeros(shape, dtype=dtype) + + +def absolute(x): + from openvino.runtime.opset14 import absolute + return absolute(x) + + +def abs(x): + from openvino.runtime.opset14 import absolute + return absolute(x) + + +def all(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + return np.all(x, axis=axis, keepdims=keepdims) + + +def any(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + return np.any(x, axis=axis, keepdims=keepdims) + + +def amax(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + return np.amax(x, axis=axis, keepdims=keepdims) + + +def amin(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + return np.amin(x, axis=axis, keepdims=keepdims) + + +def append(x1, x2, axis=None): + axis = standardize_axis_for_numpy(axis) + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.append(x1, x2, axis=axis) + + +def arange(start, stop=None, step=None, dtype=None): + if dtype is None: + dtypes_to_resolve = [ + getattr(start, "dtype", type(start)), + getattr(step, "dtype", type(step)), + ] + if stop is not None: + dtypes_to_resolve.append(getattr(stop, "dtype", type(stop))) + dtype = dtypes.result_type(*dtypes_to_resolve) + return np.arange(start, stop, step=step, dtype=dtype) + + +def arccos(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arccos(x) + + +def 
arccosh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arccosh(x) + + +def arcsin(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arcsin(x) + + +def arcsinh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arcsinh(x) + + +def arctan(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arctan(x) + + +def arctan2(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.arctan2(x1, x2) + + +def arctanh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arctanh(x) + + +def argmax(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + return np.argmax(x, axis=axis, keepdims=keepdims).astype("int32") + + +def argmin(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32") + + +def argsort(x, axis=-1): + axis = standardize_axis_for_numpy(axis) + return np.argsort(x, axis=axis).astype("int32") + + +def array(x, dtype=None): + return convert_to_tensor(x, dtype=dtype) + + +def average(x, axis=None, weights=None): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + dtypes_to_resolve = [x.dtype, float] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + x = x.astype(dtype) + if weights is not None: + weights = weights.astype(dtype) + return np.average(x, weights=weights, axis=axis) + + +def bincount(x, weights=None, minlength=0, sparse=False): + if sparse: + raise ValueError("Unsupported value `sparse=True` with numpy backend") + x = convert_to_tensor(x) + dtypes_to_resolve = [x.dtype] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + else: + dtype = "int32" + if len(x.shape) == 2: + if weights is None: + + def bincount_fn(arr): + return np.bincount(arr, minlength=minlength) + + bincounts = list(map(bincount_fn, x)) + else: + + def bincount_fn(arr_w): + return np.bincount( + arr_w[0], weights=arr_w[1], minlength=minlength + ) + + bincounts = list(map(bincount_fn, zip(x, weights))) + + return np.stack(bincounts).astype(dtype) + return np.bincount(x, weights, minlength).astype(dtype) + + +def broadcast_to(x, shape): + return np.broadcast_to(x, shape) + + +def ceil(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.ceil(x) + + +def clip(x, x_min, x_max): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if dtype == "bool": + dtype = "int32" + return np.clip(x, x_min, x_max).astype(dtype) + + +def concatenate(xs, axis=0): 
+ axis = standardize_axis_for_numpy(axis) + dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tree.map_structure( + lambda x: convert_to_tensor(x).astype(dtype), xs + ) + return np.concatenate(xs, axis=axis) + + +def conjugate(x): + return np.conjugate(x) + + +def conj(x): + return conjugate(x) + + +def copy(x): + return np.copy(x) + + +def cos(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.cos(x) + + +def cosh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.cosh(x) + + +def count_nonzero(x, axis=None): + axis = standardize_axis_for_numpy(axis) + # np.count_nonzero will return python int when axis=None, so we need + # to convert_to_tensor + return convert_to_tensor(np.count_nonzero(x, axis=axis)).astype("int32") + + +def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): + axis = standardize_axis_for_numpy(axis) + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.cross( + x1, + x2, + axisa=axisa, + axisb=axisb, + axisc=axisc, + axis=axis, + ) + + +def cumprod(x, axis=None, dtype=None): + axis = standardize_axis_for_numpy(axis) + dtype = dtypes.result_type(dtype or x.dtype) + if dtype == "bool": + dtype = "int32" + return np.cumprod(x, axis=axis, dtype=dtype) + + +def cumsum(x, axis=None, dtype=None): + axis = standardize_axis_for_numpy(axis) + dtype = dtypes.result_type(dtype or x.dtype) + if dtype == "bool": + dtype = "int32" + return np.cumsum(x, axis=axis, dtype=dtype) + + +def diag(x, k=0): + return np.diag(x, k=k) + + +def diagonal(x, offset=0, axis1=0, axis2=1): + axis1 = standardize_axis_for_numpy(axis1) + axis2 = standardize_axis_for_numpy(axis2) + return np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) + + +def diff(a, n=1, axis=-1): + return np.diff(a, n=n, axis=axis) + + +def digitize(x, bins): + return np.digitize(x, bins).astype(np.int32) + + +def dot(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = x.astype(dtype) + y = y.astype(dtype) + return np.dot(x, y) + + +def empty(shape, dtype=None): + dtype = dtype or config.floatx() + return np.empty(shape, dtype=dtype) + + +def equal(x1, x2): + return np.equal(x1, x2) + + +def exp(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = x.astype(config.floatx()) + return np.exp(x) + + +def expand_dims(x, axis): + axis = standardize_axis_for_numpy(axis) + return np.expand_dims(x, axis) + + +def expm1(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = x.astype(config.floatx()) + return np.expm1(x) + + +def flip(x, axis=None): + axis = standardize_axis_for_numpy(axis) + return np.flip(x, axis=axis) + + +def floor(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + x = x.astype(dtype) + return np.floor(x) + + +def full(shape, fill_value, dtype=None): + dtype = dtype or config.floatx() + return np.full(shape, fill_value, dtype=dtype) + + +def 
full_like(x, fill_value, dtype=None): + return np.full_like(x, fill_value, dtype=dtype) + + +def greater(x1, x2): + return np.greater(x1, x2) + + +def greater_equal(x1, x2): + return np.greater_equal(x1, x2) + + +def hstack(xs): + dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tree.map_structure( + lambda x: convert_to_tensor(x).astype(dtype), xs + ) + return np.hstack(xs) + + +def identity(n, dtype=None): + dtype = dtype or config.floatx() + return np.identity(n, dtype=dtype) + + +def imag(x): + return np.imag(x) + + +def isclose(x1, x2): + return np.isclose(x1, x2) + + +def isfinite(x): + return np.isfinite(x) + + +def isinf(x): + return np.isinf(x) + + +def isnan(x): + return np.isnan(x) + + +def less(x1, x2): + return np.less(x1, x2) + + +def less_equal(x1, x2): + return np.less_equal(x1, x2) + + +def linspace( + start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 +): + axis = standardize_axis_for_numpy(axis) + if dtype is None: + dtypes_to_resolve = [ + getattr(start, "dtype", type(start)), + getattr(stop, "dtype", type(stop)), + float, + ] + dtype = dtypes.result_type(*dtypes_to_resolve) + return np.linspace( + start, + stop, + num=num, + endpoint=endpoint, + retstep=retstep, + dtype=dtype, + axis=axis, + ) + + +def log(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return np.log(x, dtype=dtype) + + +def log10(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return np.log10(x, dtype=dtype) + + +def log1p(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return np.log1p(x, dtype=dtype) + + +def log2(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return np.log2(x, dtype=dtype) + + +def logaddexp(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.logaddexp(x1, x2) + + +def logical_and(x1, x2): + return np.logical_and(x1, x2) + + +def logical_not(x): + return np.logical_not(x) + + +def logical_or(x1, x2): + return np.logical_or(x1, x2) + + +def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): + if dtype is None: + dtypes_to_resolve = [ + getattr(start, "dtype", type(start)), + getattr(stop, "dtype", type(stop)), + float, + ] + dtype = dtypes.result_type(*dtypes_to_resolve) + return np.logspace( + start, + stop, + num=num, + endpoint=endpoint, + base=base, + dtype=dtype, + axis=axis, + ) + + +def maximum(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.maximum(x1, x2) + + +def median(x, axis=None, keepdims=False): + dtype = dtypes.result_type(x.dtype, float) + return np.median(x, axis=axis, keepdims=keepdims).astype(dtype) + + +def meshgrid(*x, indexing="xy"): + return np.meshgrid(*x, indexing=indexing) + + +def min(x, axis=None, keepdims=False, 
initial=None): + axis = standardize_axis_for_numpy(axis) + return np.min(x, axis=axis, keepdims=keepdims, initial=initial) + + +def minimum(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.minimum(x1, x2) + + +def mod(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype == "bool": + dtype = "int32" + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.mod(x1, x2) + + +def moveaxis(x, source, destination): + return np.moveaxis(x, source=source, destination=destination) + + +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): + return np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + + +def ndim(x): + return np.ndim(x) + + +def nonzero(x): + return tuple(indices.astype("int32") for indices in np.nonzero(x)) + + +def not_equal(x1, x2): + return np.not_equal(x1, x2) + + +def zeros_like(x, dtype=None): + return np.zeros_like(x, dtype=dtype) + + +def ones_like(x, dtype=None): + return np.ones_like(x, dtype=dtype) + + +def outer(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.outer(x1, x2) + + +def pad(x, pad_width, mode="constant", constant_values=None): + kwargs = {} + if constant_values is not None: + if mode != "constant": + raise ValueError( + "Argument `constant_values` can only be " + "provided when `mode == 'constant'`. " + f"Received: mode={mode}" + ) + kwargs["constant_values"] = constant_values + return np.pad(x, pad_width, mode=mode, **kwargs) + + +def prod(x, axis=None, keepdims=False, dtype=None): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + if dtype is None: + dtype = dtypes.result_type(x.dtype) + if dtype in ("bool", "int8", "int16"): + dtype = "int32" + elif dtype in ("uint8", "uint16"): + dtype = "uint32" + return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype) + + +def quantile(x, q, axis=None, method="linear", keepdims=False): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + + ori_dtype = standardize_dtype(x.dtype) + # np.quantile doesn't support bool + if ori_dtype == "bool": + x = x.astype(config.floatx()) + if ori_dtype == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + return np.quantile( + x, q, axis=axis, method=method, keepdims=keepdims + ).astype(dtype) + + +def ravel(x): + return np.ravel(x) + + +def real(x): + return np.real(x) + + +def reciprocal(x): + return np.reciprocal(x) + + +def repeat(x, repeats, axis=None): + return np.repeat(x, repeats, axis=axis) + + +def reshape(x, newshape): + return np.reshape(x, newshape) + + +def roll(x, shift, axis=None): + return np.roll(x, shift, axis=axis) + + +def sign(x): + return np.sign(x) + + +def sin(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.sin(x) + + +def sinh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.sinh(x) + + +def size(x): + return np.size(x) + + 
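+# The array-manipulation ops below delegate to NumPy, normalizing the
+# `axis` argument and promoting dtypes where needed.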
+def sort(x, axis=-1): + axis = standardize_axis_for_numpy(axis) + return np.sort(x, axis=axis) + + +def split(x, indices_or_sections, axis=0): + axis = standardize_axis_for_numpy(axis) + return np.split(x, indices_or_sections, axis=axis) + + +def stack(x, axis=0): + axis = standardize_axis_for_numpy(axis) + dtype_set = set([getattr(a, "dtype", type(a)) for a in x]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + x = tree.map_structure(lambda a: convert_to_tensor(a).astype(dtype), x) + return np.stack(x, axis=axis) + + +def std(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = x.astype(config.floatx()) + return np.std(x, axis=axis, keepdims=keepdims) + + +def swapaxes(x, axis1, axis2): + return np.swapaxes(x, axis1=axis1, axis2=axis2) + + +def take(x, indices, axis=None): + axis = standardize_axis_for_numpy(axis) + return np.take(x, indices, axis=axis) + + +def take_along_axis(x, indices, axis=None): + axis = standardize_axis_for_numpy(axis) + return np.take_along_axis(x, indices, axis=axis) + + +def tan(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.tan(x) + + +def tanh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.tanh(x) + + +def tensordot(x1, x2, axes=2): + axes = tuple(axes) if isinstance(axes, list) else axes + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.tensordot(x1, x2, axes=axes) + + +def round(x, decimals=0): + return np.round(x, decimals=decimals) + + +def tile(x, repeats): + return np.tile(x, repeats) + + +def trace(x, offset=0, axis1=0, axis2=1): + axis1 = standardize_axis_for_numpy(axis1) + axis2 = standardize_axis_for_numpy(axis2) + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if dtype not in ("int64", "uint32", "uint64"): + dtype = dtypes.result_type(dtype, "int32") + return np.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) + + +def tri(N, M=None, k=0, dtype=None): + dtype = dtype or config.floatx() + return np.tri(N, M=M, k=k, dtype=dtype) + + +def tril(x, k=0): + return np.tril(x, k=k) + + +def triu(x, k=0): + return np.triu(x, k=k) + + +def vdot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.vdot(x1, x2) + + +def vstack(xs): + dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tree.map_structure( + lambda x: convert_to_tensor(x).astype(dtype), xs + ) + return np.vstack(xs) + + +def vectorize(pyfunc, *, excluded=None, signature=None): + return np.vectorize(pyfunc, excluded=excluded, signature=signature) + + +def where(condition, x1, x2): + if x1 is not None and x2 is not None: + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = 
convert_to_tensor(x2, dtype)
+        return np.where(condition, x1, x2)
+    else:
+        return np.where(condition)
+
+
+def divide(x1, x2):
+    from openvino.runtime.opset14 import divide
+    x1, x2 = _align_operand_types(x1, x2, "divide()")
+    return divide(x1, x2)
+
+
+def divide_no_nan(x1, x2):
+    if not isinstance(x1, (int, float)):
+        x1 = convert_to_tensor(x1)
+    if not isinstance(x2, (int, float)):
+        x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(
+        getattr(x1, "dtype", type(x1)),
+        getattr(x2, "dtype", type(x2)),
+        float,
+    )
+    x1 = convert_to_tensor(x1, dtype)
+    x2 = convert_to_tensor(x2, dtype)
+    return np.where(x2 == 0, 0, np.divide(x1, x2))
+
+
+def true_divide(x1, x2):
+    return divide(x1, x2)
+
+
+def power(x1, x2):
+    from openvino.runtime.opset14 import power
+    x1, x2 = _align_operand_types(x1, x2, "power()")
+    return power(x1, x2)
+
+
+def negative(x):
+    from openvino.runtime.opset14 import negative
+    return negative(x)
+
+
+def square(x):
+    x = convert_to_tensor(x)
+    if standardize_dtype(x.dtype) == "bool":
+        x = x.astype("int32")
+    return np.square(x)
+
+
+def sqrt(x):
+    x = convert_to_tensor(x)
+    # upcast to float64 for int64 which matches JAX's behavior
+    dtype = (
+        config.floatx()
+        if standardize_dtype(x.dtype) == "int64"
+        else dtypes.result_type(x.dtype, float)
+    )
+    return np.sqrt(x, dtype=dtype)
+
+
+def squeeze(x, axis=None):
+    axis = standardize_axis_for_numpy(axis)
+    return np.squeeze(x, axis=axis)
+
+
+def transpose(x, axes=None):
+    axes = tuple(axes) if isinstance(axes, list) else axes
+    return np.transpose(x, axes=axes)
+
+
+def var(x, axis=None, keepdims=False):
+    axis = standardize_axis_for_numpy(axis)
+    x = convert_to_tensor(x)
+    compute_dtype = dtypes.result_type(x.dtype, "float32")
+    result_dtype = dtypes.result_type(x.dtype, float)
+    return np.var(x, axis=axis, keepdims=keepdims, dtype=compute_dtype).astype(
+        result_dtype
+    )
+
+
+def sum(x, axis=None, keepdims=False):
+    axis = standardize_axis_for_numpy(axis)
+    dtype = standardize_dtype(x.dtype)
+    # follow jax's rule
+    if dtype in ("bool", "int8", "int16"):
+        dtype = "int32"
+    elif dtype in ("uint8", "uint16"):
+        dtype = "uint32"
+    return np.sum(x, axis=axis, keepdims=keepdims).astype(dtype)
+
+
+def eye(N, M=None, k=0, dtype=None):
+    dtype = dtype or config.floatx()
+    return np.eye(N, M=M, k=k, dtype=dtype)
+
+
+def floor_divide(x1, x2):
+    if not isinstance(x1, (int, float)):
+        x1 = convert_to_tensor(x1)
+    if not isinstance(x2, (int, float)):
+        x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(
+        getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2))
+    )
+    x1 = convert_to_tensor(x1, dtype)
+    x2 = convert_to_tensor(x2, dtype)
+    return np.floor_divide(x1, x2)
+
+
+def logical_xor(x1, x2):
+    from openvino.runtime.opset14 import logical_xor
+    return logical_xor(x1, x2)
+
+
+def correlate(x1, x2, mode="valid"):
+    raise NotImplementedError(
+        "`correlate` is not supported with openvino backend"
+    )
+
+
+def select(condlist, choicelist, default=0):
+    raise NotImplementedError(
+        "`select` is not supported with openvino backend"
+    )
+
+
+def slogdet(x):
+    raise NotImplementedError(
+        "`slogdet` is not supported with openvino backend"
+    )
+
+
+def argpartition(x, kth, axis=-1):
+    raise NotImplementedError(
+        "`argpartition` is not supported with openvino backend"
+    )
diff --git a/keras/src/backend/openvino/random.py b/keras/src/backend/openvino/random.py
new file mode 100644
index 00000000000..028544e53b7
--- /dev/null
+++ b/keras/src/backend/openvino/random.py
@@ -0,0 +1,120 @@
+import numpy as np
+
+from 
keras.src.backend.config import floatx +from keras.src.backend.openvino.nn import softmax +from keras.src.random.seed_generator import SeedGenerator +from keras.src.random.seed_generator import draw_seed +from keras.src.random.seed_generator import make_default_seed + + +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + return rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype) + + +def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + return rng.uniform(size=shape, low=minval, high=maxval).astype(dtype) + + +def categorical(logits, num_samples, dtype="int64", seed=None): + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + output = [] + for logits_instance in logits: + probabilities = softmax(logits_instance) + classes = np.arange(logits_instance.shape[-1]) + samples = rng.choice(classes, size=num_samples, p=probabilities) + output.append(samples) + return np.array(output).astype(dtype) + + +def randint(shape, minval, maxval, dtype="int32", seed=None): + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + output = rng.integers(low=minval, high=maxval, size=shape, dtype=dtype) + return output + + +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + + lower_bound = mean - 2 * stddev + upper_bound = mean + 2 * stddev + + flat_shape = np.prod(shape) + random_numbers = np.empty(0) + + # loop until we have enough valid numbers to fill our desired shape + while random_numbers.shape[0] < flat_shape: + # Generate a batch of random numbers from a normal distribution + batch = rng.normal(loc=mean, scale=stddev, size=flat_shape) + + # Filter the numbers to keep only those within the specified bounds + valid = batch[(batch >= lower_bound) & (batch <= upper_bound)] + + # Append the valid numbers to the result array + random_numbers = np.append(random_numbers, valid) + + # Truncate the result array to the desired size and reshape it + return random_numbers[:flat_shape].astype(dtype).reshape(shape) + + +def dropout(inputs, rate, noise_shape=None, seed=None): + dtype = inputs.dtype + seed = draw_seed(seed) + + keep_prob = 1.0 - rate + + # If noise_shape is not provided, use the shape of inputs + if noise_shape is None: + noise_shape = inputs.shape + else: + # If noise_shape is provided, replace None with corresponding + # input shape + noise_shape = [ + n if n is not None else inputs.shape[i] + for i, n in enumerate(noise_shape) + ] + + rng = np.random.default_rng(seed) + mask = rng.uniform(size=noise_shape) < keep_prob + mask = np.broadcast_to(mask, inputs.shape) + return np.where( + mask, (inputs / keep_prob).astype(dtype), np.zeros_like(inputs) + ) + + +def shuffle(x, axis=0, seed=None): + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + return rng.permuted(x, axis=axis) + + +def gamma(shape, alpha, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + return rng.gamma(alpha, scale=1.0, size=shape).astype(dtype) + + +def binomial(shape, counts, probabilities, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + sample = rng.binomial(n=counts, p=probabilities, size=shape).astype(dtype) + return sample + + +def beta(shape, 
alpha, beta, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + sample = rng.beta(a=alpha, b=beta, size=shape).astype(dtype) + return sample diff --git a/keras/src/backend/openvino/rnn.py b/keras/src/backend/openvino/rnn.py new file mode 100644 index 00000000000..07f65752514 --- /dev/null +++ b/keras/src/backend/openvino/rnn.py @@ -0,0 +1,239 @@ +import numpy as np + +from keras.src import tree + + +def rnn( + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, +): + def swap_batch_timestep(input_t): + # Swap the batch and timestep dim for the incoming tensor. + axes = list(range(len(input_t.shape))) + axes[0], axes[1] = 1, 0 + return np.transpose(input_t, axes) + + if not time_major: + inputs = tree.map_structure(swap_batch_timestep, inputs) + + flattened_inputs = tree.flatten(inputs) + time_steps = flattened_inputs[0].shape[0] + + if mask is not None: + if mask.dtype != "bool": + mask = mask.astype("bool") + if len(mask.shape) == 2: + mask = np.expand_dims(mask, axis=-1) + if not time_major: + mask = swap_batch_timestep(mask) + + if constants is None: + constants = [] + + def _expand_mask(mask_t, input_t, fixed_dim=1): + if tree.is_nested(mask_t): + raise ValueError( + f"mask_t is expected to be tensor, but got {mask_t}" + ) + if tree.is_nested(input_t): + raise ValueError( + f"input_t is expected to be tensor, but got {input_t}" + ) + rank_diff = len(input_t.shape) - len(mask_t.shape) + for _ in range(rank_diff): + mask_t = np.expand_dims(mask_t, -1) + multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:]) + return np.tile(mask_t, multiples) + + if unroll: + if not time_steps: + raise ValueError("Unrolling requires a fixed number of timesteps.") + states = tuple(initial_states) + successive_states = [] + successive_outputs = [] + + # Process the input tensors. The input tensor need to be split on the + # time_step dim, and reverse if go_backwards is True. In the case of + # nested input, the input is flattened and then transformed + # individually. 
The result of this will be a tuple of lists; each
+        # item in the tuple is a list of tensors with shape (batch, feature)
+        def _process_single_input_t(input_t):
+            input_t = unstack(input_t)  # unstack for time_step dim
+            if go_backwards:
+                input_t.reverse()
+            return input_t
+
+        if tree.is_nested(inputs):
+            processed_input = tree.map_structure(
+                _process_single_input_t, inputs
+            )
+        else:
+            processed_input = (_process_single_input_t(inputs),)
+
+        def _get_input_tensor(time):
+            inp = [t_[time] for t_ in processed_input]
+            return tree.pack_sequence_as(inputs, inp)
+
+        if mask is not None:
+            mask_list = unstack(mask)
+            if go_backwards:
+                mask_list.reverse()
+
+            for i in range(time_steps):
+                inp = _get_input_tensor(i)
+                mask_t = mask_list[i]
+                output, new_states = step_function(
+                    inp, tuple(states) + tuple(constants)
+                )
+                tiled_mask_t = _expand_mask(mask_t, output)
+
+                if not successive_outputs:
+                    prev_output = np.zeros_like(output)
+                else:
+                    prev_output = successive_outputs[-1]
+
+                output = np.where(tiled_mask_t, output, prev_output)
+
+                flat_states = tree.flatten(states)
+                flat_new_states = tree.flatten(new_states)
+                tiled_mask_t = tuple(
+                    _expand_mask(mask_t, s) for s in flat_states
+                )
+                flat_final_states = tuple(
+                    np.where(m, s, ps)
+                    for m, s, ps in zip(
+                        tiled_mask_t, flat_new_states, flat_states
+                    )
+                )
+                states = tree.pack_sequence_as(states, flat_final_states)
+
+                if return_all_outputs:
+                    successive_outputs.append(output)
+                    successive_states.append(states)
+                else:
+                    successive_outputs = [output]
+                    successive_states = [states]
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = np.stack(successive_outputs)
+
+        else:  # mask is None
+            for i in range(time_steps):
+                inp = _get_input_tensor(i)
+                output, states = step_function(
+                    inp, tuple(states) + tuple(constants)
+                )
+                if return_all_outputs:
+                    successive_outputs.append(output)
+                    successive_states.append(states)
+                else:
+                    successive_outputs = [output]
+                    successive_states = [states]
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = np.stack(successive_outputs)
+
+    else:  # Unroll == False
+        if mask is not None:
+
+            def _step(states, current_input):
+                current_input, current_mask = current_input
+                is_masked = np.all(
+                    np.logical_not(current_mask), axis=-1, keepdims=True
+                )
+
+                output_t, new_states = step_function(current_input, states)
+
+                if zero_output_for_mask:
+                    masked_outs = np.where(
+                        is_masked, np.zeros_like(output_t), output_t
+                    )
+                else:
+                    # Assume the first state is the previous output.
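+                    # (masked timesteps thus carry the previous output
+                    # forward unchanged)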
+ output_tm1 = states[0] + masked_outs = np.where(is_masked, output_tm1, output_t) + + new_states = [ + np.where(is_masked, s, ns) + for s, ns in zip(states, new_states) + ] + return (new_states, masked_outs) + + scan_xs = (inputs, mask) + + else: + + def _step(states, current_input): + output_t, new_states = step_function(current_input, states) + return new_states, output_t + + scan_xs = inputs + + new_states, outputs = numpy_scan( + f=_step, + init=initial_states, + xs=scan_xs, + reverse=go_backwards, + mask=mask, + ) + + if go_backwards: + outputs = np.flip(outputs, axis=0) + last_output = outputs[-1] + + if not time_major: + outputs = tree.map_structure(swap_batch_timestep, outputs) + + return last_output, outputs, new_states + + +def lstm(*args, **kwargs): + raise NotImplementedError + + +def gru(*args, **kwargs): + raise NotImplementedError + + +def unstack(x, axis=0): + return [x.take(i, axis) for i in range(x.shape[axis])] + + +def numpy_scan(f, init, xs, reverse=False, mask=None): + states = init + outputs = [] + + if mask is not None: + x, mask = xs + x = np.flip(x, axis=0) if reverse else x + mask = np.flip(mask, axis=0) if reverse else mask + + for each_x, each_mask in zip(x, mask): + states, output = f(states, (each_x, each_mask)) + outputs.append(output) + else: + xs = np.flip(xs, axis=0) if reverse else xs + + for x in xs: + states, output = f(states, x) + outputs.append(output) + + outputs = np.array(outputs) + + if reverse: + outputs = np.flip(outputs, axis=0) + + return states, outputs + + +def cudnn_ok(*args, **kwargs): + return False diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py new file mode 100644 index 00000000000..5c510d6a958 --- /dev/null +++ b/keras/src/backend/openvino/trainer.py @@ -0,0 +1,322 @@ +import numpy as np + +from keras.src import backend +from keras.src import callbacks as callbacks_module +from keras.src import tree +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.backend.openvino.core import is_tensor +from keras.src.trainers import trainer as base_trainer +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.epoch_iterator import EpochIterator +from keras.src.utils import traceback_utils + + +class OpenVINOTrainer(base_trainer.Trainer): + def __init__(self): + super().__init__() + self.test_function = None + self.predict_function = None + + def test_step(self, data): + ( + x, + y, + sample_weight, + ) = data_adapter_utils.unpack_x_y_sample_weight(data) + if self._call_has_training_arg: + y_pred = self(x, training=False) + else: + y_pred = self(x) + loss = self.compute_loss( + x=x, y=y, y_pred=y_pred, sample_weight=sample_weight + ) + self._loss_tracker.update_state( + loss, sample_weight=tree.flatten(x)[0].shape[0] + ) + return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight) + + def predict_step(self, data): + x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) + if self._call_has_training_arg: + y_pred = self(x, training=False) + else: + y_pred = self(x) + return y_pred + + def make_test_function(self, force=False): + if self.test_function is not None and not force: + return self.test_function + + def one_test_step(data): + data = data[0] + return self.test_step(data) + + def multi_test_steps(data): + for single_step_data in data: + logs = one_test_step([single_step_data]) + return logs + + if self.steps_per_execution > 1: + test_step = multi_test_steps + 
else:
+            test_step = one_test_step
+
+        self.test_function = test_step
+
+    def make_predict_function(self, force=False):
+        if self.predict_function is not None and not force:
+            return self.predict_function
+
+        def one_predict_step(data):
+            data = data[0]
+            return self.predict_step(data)
+
+        def multi_predict_steps(data):
+            outputs = one_predict_step(data[:1])
+
+            for single_step_data in data[1:]:
+                step_outputs = one_predict_step([single_step_data])
+                outputs = tree.map_structure(
+                    lambda t1, t2: np.concatenate([t1, t2]),
+                    outputs,
+                    step_outputs,
+                )
+            return outputs
+
+        if self.steps_per_execution > 1:
+            predict_step = multi_predict_steps
+        else:
+            predict_step = one_predict_step
+
+        self.predict_function = predict_step
+
+    def _symbolic_build(self, data_batch):
+        model_unbuilt = not all(layer.built for layer in self._flatten_layers())
+        compile_metrics_unbuilt = (
+            self._compile_metrics is not None
+            and not self._compile_metrics.built
+        )
+        if model_unbuilt or compile_metrics_unbuilt:
+            # Create symbolic tensors matching an input batch.
+
+            def to_symbolic_input(v):
+                if is_tensor(v):
+                    return KerasTensor(v.shape, standardize_dtype(v.dtype))
+                return v
+
+            data_batch = tree.map_structure(to_symbolic_input, data_batch)
+            (
+                x,
+                y,
+                sample_weight,
+            ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
+            # Build all model state with `backend.compute_output_spec`.
+            try:
+                y_pred = backend.compute_output_spec(self, x)
+            except Exception as e:
+                raise RuntimeError(
+                    "Unable to automatically build the model. "
+                    "Please build it yourself before calling "
+                    "fit/evaluate/predict. "
+                    "A model is 'built' when its variables have "
+                    "been created and its `self.built` attribute "
+                    "is True. Usually, calling the model on a batch "
+                    "of data is the right way to build it."
+                ) from e
+            if compile_metrics_unbuilt:
+                # Build all metric state with `backend.compute_output_spec`.
+                backend.compute_output_spec(
+                    self.compute_metrics,
+                    x,
+                    y,
+                    y_pred,
+                    sample_weight=sample_weight,
+                )
+        self._post_build()
+
+    def fit(
+        self,
+        x=None,
+        y=None,
+        batch_size=None,
+        epochs=1,
+        verbose="auto",
+        callbacks=None,
+        validation_split=0.0,
+        validation_data=None,
+        shuffle=True,
+        class_weight=None,
+        sample_weight=None,
+        initial_epoch=0,
+        steps_per_epoch=None,
+        validation_steps=None,
+        validation_batch_size=None,
+        validation_freq=1,
+    ):
+        raise NotImplementedError("fit not implemented for OpenVINO backend.")
+
+    @traceback_utils.filter_traceback
+    def predict(
+        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+    ):
+        # Create an iterator that yields batches of input data.
+        epoch_iterator = EpochIterator(
+            x=x,
+            batch_size=batch_size,
+            steps_per_epoch=steps,
+            shuffle=False,
+            steps_per_execution=self.steps_per_execution,
+        )
+
+        # Container that configures and calls callbacks.
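+        # A `History` callback is always attached, and a progress bar is
+        # added when `verbose` is non-zero.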
+ if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + def append_to_outputs(batch_outputs, outputs): + if outputs is None: + outputs = tree.map_structure( + lambda batch_output: [batch_output], + batch_outputs, + ) + else: + tree.map_structure_up_to( + batch_outputs, + lambda output, batch_output: output.append(batch_output), + outputs, + batch_outputs, + ) + return outputs + + self.make_predict_function() + self.stop_predicting = False + callbacks.on_predict_begin() + outputs = None + for step, data in epoch_iterator.enumerate_epoch(): + callbacks.on_predict_batch_begin(step) + batch_outputs = self.predict_function(data) + outputs = append_to_outputs(batch_outputs, outputs) + callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) + if self.stop_predicting: + break + callbacks.on_predict_end() + return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) + + @traceback_utils.filter_traceback + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + # TODO: respect compiled trainable state + use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) + if kwargs: + raise ValueError(f"Arguments not recognized: {kwargs}") + + if use_cached_eval_dataset: + epoch_iterator = self._eval_epoch_iterator + else: + # Create an iterator that yields batches of input/target data. + epoch_iterator = EpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + if not all(layer.built for layer in self._flatten_layers()): + # Build the model on one batch of data. + for _, data in epoch_iterator.enumerate_epoch(): + data_batch = data[0] + self._symbolic_build(data_batch) + break + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + self.make_test_function() + self.stop_evaluating = False + callbacks.on_test_begin() + logs = None + self.reset_metrics() + for step, data in epoch_iterator.enumerate_epoch(): + callbacks.on_test_batch_begin(step) + logs = self.test_function(data) + logs = self._pythonify_logs(logs) + callbacks.on_test_batch_end(step, logs) + if self.stop_evaluating: + break + logs = self._get_metrics_result_or_logs(logs) + callbacks.on_test_end(logs) + + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + def train_on_batch( + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, + ): + raise NotImplementedError( + "train_on_batch not implemented for OpenVINO backend." 
+        )
+
+    def test_on_batch(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        return_dict=False,
+    ):
+        self._assert_compile_called("test_on_batch")
+
+        data = (x, y, sample_weight)
+
+        # Maybe build model
+        self._symbolic_build(data)
+        self.make_test_function()
+
+        logs = self.test_function([data])
+        logs = tree.map_structure(lambda x: np.array(x), logs)
+        if return_dict:
+            return logs
+        return self._flatten_metrics_in_order(logs)
+
+    def predict_on_batch(self, x):
+        self.make_predict_function()
+        batch_outputs = self.predict_function([(x,)])
+        batch_outputs = tree.map_structure(
+            backend.convert_to_numpy, batch_outputs
+        )
+        return batch_outputs
diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index 99856bbea04..763edae9a10 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -50,6 +50,8 @@
     from keras.src.backend.torch.layer import TorchLayer as BackendLayer
 elif backend.backend() == "numpy":
     from keras.src.backend.numpy.layer import NumpyLayer as BackendLayer
+elif backend.backend() == "openvino":
+    from keras.src.backend.openvino.layer import NumpyLayer as BackendLayer
 else:
     raise RuntimeError(
         f"Backend '{backend.backend()}' must implement a layer mixin class."
@@ -289,6 +291,9 @@ def __init__(
         self._convert_input_args = True
         # Whether to allow non-tensors as positional arguments in `call()`.
         self._allow_non_tensor_positional_args = False
+        if backend.backend() == "openvino":
+            self._allow_non_tensor_positional_args = True
+
         # Dict of shapes that were used to call `build()`.
         self._build_shapes_dict = None
         # Parent path
diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py
index 85533d0a32e..c4945843453 100644
--- a/keras/src/models/functional.py
+++ b/keras/src/models/functional.py
@@ -163,8 +163,18 @@ def layers(self):
         return layers
 
     def call(self, inputs, training=None, mask=None):
+        from keras.src.backend.config import backend
         # Add support for training, masking
         inputs = self._standardize_inputs(inputs)
+        if backend() == "openvino":
+            from keras.src.backend.openvino.core import get_device
+            if self._ov_device != get_device():
+                # update the current device and re-compile the model
+                self._ov_device = get_device()
+                self._ov_compiled_model = self._ov_core.compile_model(self._ov_model, self._ov_device)
+            outputs = self._ov_compiled_model(inputs)
+            return unpack_singleton(tree.pack_sequence_as(self._outputs_struct, outputs.to_tuple()))
+
         if mask is None:
             masks = [None] * len(inputs)
         else:
diff --git a/keras/src/models/model.py b/keras/src/models/model.py
index 3039275c750..c573b4563b6 100644
--- a/keras/src/models/model.py
+++ b/keras/src/models/model.py
@@ -23,6 +23,8 @@
     from keras.src.backend.torch.trainer import TorchTrainer as Trainer
 elif backend.backend() == "numpy":
     from keras.src.backend.numpy.trainer import NumpyTrainer as Trainer
+elif backend.backend() == "openvino":
+    from keras.src.backend.openvino.trainer import OpenVINOTrainer as Trainer
 else:
     raise RuntimeError(
         f"Backend '{backend.backend()}' must implement the Trainer class."
diff --git a/keras/src/ops/function.py b/keras/src/ops/function.py
index 49df4073e06..ca60ad36ae5 100644
--- a/keras/src/ops/function.py
+++ b/keras/src/ops/function.py
@@ -81,6 +81,30 @@ def __init__(self, inputs, outputs, name=None):
         self._nodes_by_depth = nodes_by_depth
         self._operations = operations
         self._operations_by_depth = operations_by_depth
+        if backend() == "openvino":
+            from keras.src.backend.openvino.core import OPENVINO_DTYPES
+            from keras.src.backend.openvino.core import get_device
+            import openvino as ov
+            import openvino.runtime.opset14 as ov_opset
+            from openvino import Core
+            # prepare OpenVINO parameters
+            ov_inputs = []
+            for _input in self._inputs:
+                ov_type = OPENVINO_DTYPES[_input.dtype]
+                ov_shape = _input.shape
+                ov_shape = list(ov_shape)
+                for i in range(len(ov_shape)):
+                    if ov_shape[i] is None:
+                        ov_shape[i] = -1
+                param = ov_opset.parameter(shape=ov_shape, dtype=ov_type)
+                ov_inputs.append(param)
+            # build the OpenVINO graph (ov.Model)
+            ov_outputs = self._run_through_graph(ov_inputs, operation_fn=lambda op: op)
+            ov_outputs = tree.flatten(ov_outputs)
+            self._ov_model = ov.Model(results=ov_outputs, parameters=ov_inputs)
+            self._ov_core = Core()
+            self._ov_device = get_device()
+            self._ov_compiled_model = self._ov_core.compile_model(self._ov_model, self._ov_device)
 
     @property
     def operations(self):
@@ -121,6 +145,15 @@ def compute_output_spec(self, inputs):
     def call(self, inputs):
         """Computes output tensors for new inputs."""
         self._assert_input_compatibility(inputs)
+        if backend() == "openvino":
+            from keras.src.backend.openvino.core import get_device
+            if self._ov_device != get_device():
+                # update the current device and re-compile the model
+                self._ov_device = get_device()
+                self._ov_compiled_model = self._ov_core.compile_model(self._ov_model, self._ov_device)
+            inputs = tree.flatten(inputs)
+            outputs = self._ov_compiled_model(inputs)
+            return tree.pack_sequence_as(self._outputs_struct, outputs.to_tuple())
         return self._run_through_graph(inputs, operation_fn=lambda op: op)
 
     def _run_through_graph(self, inputs, operation_fn, call_fn=None):
diff --git a/keras/src/utils/backend_utils.py b/keras/src/utils/backend_utils.py
index 9a82fd464eb..f9ea4a02348 100644
--- a/keras/src/utils/backend_utils.py
+++ b/keras/src/utils/backend_utils.py
@@ -86,7 +86,10 @@ def __getattr__(self, name):
             from keras.src import backend as numpy_backend
 
             return getattr(numpy_backend, name)
+        if self._backend == "openvino":
+            from keras.src.backend import openvino as openvino_backend
+            return getattr(openvino_backend, name)
 
 
 @keras_export("keras.config.set_backend")
 def set_backend(backend):

From 3e0772f809e3c34b35302189a4ce5246ba62efa7 Mon Sep 17 00:00:00 2001
From: "Kazantsev, Roman"
Date: Mon, 23 Sep 2024 21:13:54 +0400
Subject: [PATCH 02/19] Mark all unsupported ops from numpy space

Signed-off-by: Kazantsev, Roman
---
 keras/src/backend/openvino/numpy.py | 956 +++++++++++-----------------
 1 file changed, 371 insertions(+), 585 deletions(-)

diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py
index 8bad937dc35..f6d91af9456 100644
--- a/keras/src/backend/openvino/numpy.py
+++ b/keras/src/backend/openvino/numpy.py
@@ -1,10 +1,4 @@
-import numpy as np
-
-from keras.src import tree
-from keras.src.backend import config
-from keras.src.backend import standardize_dtype
 from keras.src.backend.common import dtypes
-from keras.src.backend.common.backend_utils import standardize_axis_for_numpy
 from keras.src.backend.openvino.core import convert_to_tensor, 
ov_to_keras_type, OPENVINO_DTYPES @@ -34,21 +28,9 @@ def add(x1, x2): def einsum(subscripts, *operands, **kwargs): - operands = tree.map_structure(convert_to_tensor, operands) - dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands)) - # When operands are of int8, we cast the result to int32 to align with - # the behavior of jax. - if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8": - compute_dtype = "int32" # prevent overflow - result_dtype = "int32" - else: - result_dtype = dtypes.result_type(*dtypes_to_resolve) - compute_dtype = result_dtype - # TODO: np.einsum doesn't support bfloat16 - if compute_dtype == "bfloat16": - compute_dtype = "float32" - operands = tree.map_structure(lambda x: x.astype(compute_dtype), operands) - return np.einsum(subscripts, *operands, **kwargs).astype(result_dtype) + raise NotImplementedError( + "`einsum` is not supported with openvino backend" + ) def subtract(x1, x2): @@ -58,17 +40,9 @@ def subtract(x1, x2): def matmul(x1, x2): - x1 = convert_to_tensor(x1) - x2 = convert_to_tensor(x2) - # When both x1 and x2 are of int8, we cast the outputs to int32 to align - # with jax - x1_dtype = standardize_dtype(x1.dtype) - x2_dtype = standardize_dtype(x2.dtype) - if x1_dtype == "int8" and x2_dtype == "int8": - dtype = "int32" - else: - dtype = dtypes.result_type(x1.dtype, x2.dtype) - return np.matmul(x1, x2).astype(dtype) + raise NotImplementedError( + "`matmul` is not supported with openvino backend" + ) def multiply(x1, x2): @@ -78,29 +52,27 @@ def multiply(x1, x2): def mean(x, axis=None, keepdims=False): - axis = standardize_axis_for_numpy(axis) - x = convert_to_tensor(x) - ori_dtype = standardize_dtype(x.dtype) - if "int" in ori_dtype or ori_dtype == "bool": - result_dtype = dtypes.result_type(x.dtype, "float32") - else: - result_dtype = ori_dtype - return np.mean(x, axis=axis, keepdims=keepdims).astype(result_dtype) + raise NotImplementedError( + "`mean` is not supported with openvino backend" + ) def max(x, axis=None, keepdims=False, initial=None): - axis = standardize_axis_for_numpy(axis) - return np.max(x, axis=axis, keepdims=keepdims, initial=initial) + raise NotImplementedError( + "`max` is not supported with openvino backend" + ) def ones(shape, dtype=None): - dtype = dtype or config.floatx() - return np.ones(shape, dtype=dtype) + raise NotImplementedError( + "`ones` is not supported with openvino backend" + ) def zeros(shape, dtype=None): - dtype = dtype or config.floatx() - return np.zeros(shape, dtype=dtype) + raise NotImplementedError( + "`zeros` is not supported with openvino backend" + ) def absolute(x): @@ -114,129 +86,99 @@ def abs(x): def all(x, axis=None, keepdims=False): - axis = standardize_axis_for_numpy(axis) - return np.all(x, axis=axis, keepdims=keepdims) + raise NotImplementedError( + "`all` is not supported with openvino backend" + ) def any(x, axis=None, keepdims=False): - axis = standardize_axis_for_numpy(axis) - return np.any(x, axis=axis, keepdims=keepdims) + raise NotImplementedError( + "`any` is not supported with openvino backend" + ) def amax(x, axis=None, keepdims=False): - axis = standardize_axis_for_numpy(axis) - return np.amax(x, axis=axis, keepdims=keepdims) + raise NotImplementedError( + "`amax` is not supported with openvino backend" + ) def amin(x, axis=None, keepdims=False): - axis = standardize_axis_for_numpy(axis) - return np.amin(x, axis=axis, keepdims=keepdims) + raise NotImplementedError( + "`amin` is not supported with openvino backend" + ) def append(x1, x2, axis=None): - axis = 
standardize_axis_for_numpy(axis) - x1 = convert_to_tensor(x1) - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type(x1.dtype, x2.dtype) - x1 = x1.astype(dtype) - x2 = x2.astype(dtype) - return np.append(x1, x2, axis=axis) + raise NotImplementedError( + "`append` is not supported with openvino backend" + ) def arange(start, stop=None, step=None, dtype=None): - if dtype is None: - dtypes_to_resolve = [ - getattr(start, "dtype", type(start)), - getattr(step, "dtype", type(step)), - ] - if stop is not None: - dtypes_to_resolve.append(getattr(stop, "dtype", type(stop))) - dtype = dtypes.result_type(*dtypes_to_resolve) - return np.arange(start, stop, step=step, dtype=dtype) + raise NotImplementedError( + "`arange` is not supported with openvino backend" + ) def arccos(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.arccos(x) + raise NotImplementedError( + "`arccos` is not supported with openvino backend" + ) def arccosh(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.arccosh(x) + raise NotImplementedError( + "`arccosh` is not supported with openvino backend" + ) def arcsin(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.arcsin(x) + raise NotImplementedError( + "`arcsin` is not supported with openvino backend" + ) def arcsinh(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.arcsinh(x) + raise NotImplementedError( + "`arcsinh` is not supported with openvino backend" + ) def arctan(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.arctan(x) + raise NotImplementedError( + "`arctan` is not supported with openvino backend" + ) def arctan2(x1, x2): - x1 = convert_to_tensor(x1) - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type(x1.dtype, x2.dtype, float) - x1 = x1.astype(dtype) - x2 = x2.astype(dtype) - return np.arctan2(x1, x2) + raise NotImplementedError( + "`arctan2` is not supported with openvino backend" + ) def arctanh(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.arctanh(x) + raise NotImplementedError( + "`arctanh` is not supported with openvino backend" + ) def argmax(x, axis=None, keepdims=False): - axis = standardize_axis_for_numpy(axis) - return np.argmax(x, axis=axis, keepdims=keepdims).astype("int32") + raise NotImplementedError( + "`argmax` is not supported with openvino backend" + ) def argmin(x, axis=None, keepdims=False): - axis = standardize_axis_for_numpy(axis) - return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32") + raise NotImplementedError( + "`argmin` is not supported with openvino backend" + ) def argsort(x, axis=-1): - axis = standardize_axis_for_numpy(axis) - return np.argsort(x, axis=axis).astype("int32") + raise NotImplementedError( + "`argsort` is not supported with openvino backend" + ) def array(x, dtype=None): @@ -244,715 
+186,587 @@ def array(x, dtype=None): def average(x, axis=None, weights=None): - axis = standardize_axis_for_numpy(axis) - x = convert_to_tensor(x) - dtypes_to_resolve = [x.dtype, float] - if weights is not None: - weights = convert_to_tensor(weights) - dtypes_to_resolve.append(weights.dtype) - dtype = dtypes.result_type(*dtypes_to_resolve) - x = x.astype(dtype) - if weights is not None: - weights = weights.astype(dtype) - return np.average(x, weights=weights, axis=axis) + raise NotImplementedError( + "`average` is not supported with openvino backend" + ) def bincount(x, weights=None, minlength=0, sparse=False): - if sparse: - raise ValueError("Unsupported value `sparse=True` with numpy backend") - x = convert_to_tensor(x) - dtypes_to_resolve = [x.dtype] - if weights is not None: - weights = convert_to_tensor(weights) - dtypes_to_resolve.append(weights.dtype) - dtype = dtypes.result_type(*dtypes_to_resolve) - else: - dtype = "int32" - if len(x.shape) == 2: - if weights is None: - - def bincount_fn(arr): - return np.bincount(arr, minlength=minlength) - - bincounts = list(map(bincount_fn, x)) - else: - - def bincount_fn(arr_w): - return np.bincount( - arr_w[0], weights=arr_w[1], minlength=minlength - ) - - bincounts = list(map(bincount_fn, zip(x, weights))) - - return np.stack(bincounts).astype(dtype) - return np.bincount(x, weights, minlength).astype(dtype) + raise NotImplementedError( + "`bincount` is not supported with openvino backend" + ) def broadcast_to(x, shape): - return np.broadcast_to(x, shape) + raise NotImplementedError( + "`broadcast_to` is not supported with openvino backend" + ) def ceil(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.ceil(x) + raise NotImplementedError( + "`ceil` is not supported with openvino backend" + ) def clip(x, x_min, x_max): - x = convert_to_tensor(x) - dtype = standardize_dtype(x.dtype) - if dtype == "bool": - dtype = "int32" - return np.clip(x, x_min, x_max).astype(dtype) + raise NotImplementedError( + "`clip` is not supported with openvino backend" + ) def concatenate(xs, axis=0): - axis = standardize_axis_for_numpy(axis) - dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) - if len(dtype_set) > 1: - dtype = dtypes.result_type(*dtype_set) - xs = tree.map_structure( - lambda x: convert_to_tensor(x).astype(dtype), xs - ) - return np.concatenate(xs, axis=axis) + raise NotImplementedError( + "`concatenate` is not supported with openvino backend" + ) def conjugate(x): - return np.conjugate(x) + raise NotImplementedError( + "`conjugate` is not supported with openvino backend" + ) def conj(x): - return conjugate(x) + raise NotImplementedError( + "`conj` is not supported with openvino backend" + ) def copy(x): - return np.copy(x) + raise NotImplementedError( + "`copy` is not supported with openvino backend" + ) def cos(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.cos(x) + raise NotImplementedError( + "`cos` is not supported with openvino backend" + ) def cosh(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.cosh(x) + raise NotImplementedError( + "`cosh` is not supported with openvino backend" + ) def count_nonzero(x, axis=None): - 
axis = standardize_axis_for_numpy(axis) - # np.count_nonzero will return python int when axis=None, so we need - # to convert_to_tensor - return convert_to_tensor(np.count_nonzero(x, axis=axis)).astype("int32") + raise NotImplementedError( + "`count_nonzero` is not supported with openvino backend" + ) def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): - axis = standardize_axis_for_numpy(axis) - x1 = convert_to_tensor(x1) - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type(x1.dtype, x2.dtype) - x1 = x1.astype(dtype) - x2 = x2.astype(dtype) - return np.cross( - x1, - x2, - axisa=axisa, - axisb=axisb, - axisc=axisc, - axis=axis, + raise NotImplementedError( + "`cross` is not supported with openvino backend" ) def cumprod(x, axis=None, dtype=None): - axis = standardize_axis_for_numpy(axis) - dtype = dtypes.result_type(dtype or x.dtype) - if dtype == "bool": - dtype = "int32" - return np.cumprod(x, axis=axis, dtype=dtype) + raise NotImplementedError( + "`cumprod` is not supported with openvino backend" + ) def cumsum(x, axis=None, dtype=None): - axis = standardize_axis_for_numpy(axis) - dtype = dtypes.result_type(dtype or x.dtype) - if dtype == "bool": - dtype = "int32" - return np.cumsum(x, axis=axis, dtype=dtype) + raise NotImplementedError( + "`cumsum` is not supported with openvino backend" + ) def diag(x, k=0): - return np.diag(x, k=k) + raise NotImplementedError( + "`diag` is not supported with openvino backend" + ) def diagonal(x, offset=0, axis1=0, axis2=1): - axis1 = standardize_axis_for_numpy(axis1) - axis2 = standardize_axis_for_numpy(axis2) - return np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) + raise NotImplementedError( + "`diagonal` is not supported with openvino backend" + ) def diff(a, n=1, axis=-1): - return np.diff(a, n=n, axis=axis) + raise NotImplementedError( + "`diff` is not supported with openvino backend" + ) def digitize(x, bins): - return np.digitize(x, bins).astype(np.int32) + raise NotImplementedError( + "`digitize` is not supported with openvino backend" + ) def dot(x, y): - x = convert_to_tensor(x) - y = convert_to_tensor(y) - dtype = dtypes.result_type(x.dtype, y.dtype) - x = x.astype(dtype) - y = y.astype(dtype) - return np.dot(x, y) + raise NotImplementedError( + "`dot` is not supported with openvino backend" + ) def empty(shape, dtype=None): - dtype = dtype or config.floatx() - return np.empty(shape, dtype=dtype) + raise NotImplementedError( + "`empty` is not supported with openvino backend" + ) def equal(x1, x2): - return np.equal(x1, x2) + raise NotImplementedError( + "`equal` is not supported with openvino backend" + ) def exp(x): - x = convert_to_tensor(x) - ori_dtype = standardize_dtype(x.dtype) - if "int" in ori_dtype or ori_dtype == "bool": - x = x.astype(config.floatx()) - return np.exp(x) + raise NotImplementedError( + "`exp` is not supported with openvino backend" + ) def expand_dims(x, axis): - axis = standardize_axis_for_numpy(axis) - return np.expand_dims(x, axis) + raise NotImplementedError( + "`expand_dims` is not supported with openvino backend" + ) def expm1(x): - x = convert_to_tensor(x) - ori_dtype = standardize_dtype(x.dtype) - if "int" in ori_dtype or ori_dtype == "bool": - x = x.astype(config.floatx()) - return np.expm1(x) + raise NotImplementedError( + "`expm1` is not supported with openvino backend" + ) def flip(x, axis=None): - axis = standardize_axis_for_numpy(axis) - return np.flip(x, axis=axis) + raise NotImplementedError( + "`flip` is not supported with openvino backend" + ) def floor(x): - x = 
convert_to_tensor(x) - dtype = ( - config.floatx() - if standardize_dtype(x.dtype) == "int64" - else dtypes.result_type(x.dtype, float) + raise NotImplementedError( + "`floor` is not supported with openvino backend" ) - x = x.astype(dtype) - return np.floor(x) def full(shape, fill_value, dtype=None): - dtype = dtype or config.floatx() - return np.full(shape, fill_value, dtype=dtype) + raise NotImplementedError( + "`full` is not supported with openvino backend" + ) def full_like(x, fill_value, dtype=None): - return np.full_like(x, fill_value, dtype=dtype) + raise NotImplementedError( + "`full_like` is not supported with openvino backend" + ) def greater(x1, x2): - return np.greater(x1, x2) + raise NotImplementedError( + "`greater` is not supported with openvino backend" + ) def greater_equal(x1, x2): - return np.greater_equal(x1, x2) + raise NotImplementedError( + "`greater_equal` is not supported with openvino backend" + ) def hstack(xs): - dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) - if len(dtype_set) > 1: - dtype = dtypes.result_type(*dtype_set) - xs = tree.map_structure( - lambda x: convert_to_tensor(x).astype(dtype), xs - ) - return np.hstack(xs) + raise NotImplementedError( + "`hstack` is not supported with openvino backend" + ) def identity(n, dtype=None): - dtype = dtype or config.floatx() - return np.identity(n, dtype=dtype) + raise NotImplementedError( + "`identity` is not supported with openvino backend" + ) def imag(x): - return np.imag(x) + raise NotImplementedError( + "`imag` is not supported with openvino backend" + ) def isclose(x1, x2): - return np.isclose(x1, x2) + raise NotImplementedError( + "`isclose` is not supported with openvino backend" + ) def isfinite(x): - return np.isfinite(x) + raise NotImplementedError( + "`isfinite` is not supported with openvino backend" + ) def isinf(x): - return np.isinf(x) + raise NotImplementedError( + "`isinf` is not supported with openvino backend" + ) def isnan(x): - return np.isnan(x) + raise NotImplementedError( + "`isnan` is not supported with openvino backend" + ) def less(x1, x2): - return np.less(x1, x2) + raise NotImplementedError( + "`less` is not supported with openvino backend" + ) def less_equal(x1, x2): - return np.less_equal(x1, x2) + raise NotImplementedError( + "`less_equal` is not supported with openvino backend" + ) def linspace( start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 ): - axis = standardize_axis_for_numpy(axis) - if dtype is None: - dtypes_to_resolve = [ - getattr(start, "dtype", type(start)), - getattr(stop, "dtype", type(stop)), - float, - ] - dtype = dtypes.result_type(*dtypes_to_resolve) - return np.linspace( - start, - stop, - num=num, - endpoint=endpoint, - retstep=retstep, - dtype=dtype, - axis=axis, + raise NotImplementedError( + "`linspace` is not supported with openvino backend" ) def log(x): - x = convert_to_tensor(x) - dtype = ( - config.floatx() - if standardize_dtype(x.dtype) == "int64" - else dtypes.result_type(x.dtype, float) + raise NotImplementedError( + "`log` is not supported with openvino backend" ) - return np.log(x, dtype=dtype) def log10(x): - x = convert_to_tensor(x) - dtype = ( - config.floatx() - if standardize_dtype(x.dtype) == "int64" - else dtypes.result_type(x.dtype, float) + raise NotImplementedError( + "`log10` is not supported with openvino backend" ) - return np.log10(x, dtype=dtype) def log1p(x): - x = convert_to_tensor(x) - dtype = ( - config.floatx() - if standardize_dtype(x.dtype) == "int64" - else dtypes.result_type(x.dtype, 
float) + raise NotImplementedError( + "`log1p` is not supported with openvino backend" ) - return np.log1p(x, dtype=dtype) def log2(x): - x = convert_to_tensor(x) - dtype = ( - config.floatx() - if standardize_dtype(x.dtype) == "int64" - else dtypes.result_type(x.dtype, float) + raise NotImplementedError( + "`log2` is not supported with openvino backend" ) - return np.log2(x, dtype=dtype) def logaddexp(x1, x2): - x1 = convert_to_tensor(x1) - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type(x1.dtype, x2.dtype, float) - x1 = x1.astype(dtype) - x2 = x2.astype(dtype) - return np.logaddexp(x1, x2) + raise NotImplementedError( + "`logaddexp` is not supported with openvino backend" + ) def logical_and(x1, x2): - return np.logical_and(x1, x2) + raise NotImplementedError( + "`logical_and` is not supported with openvino backend" + ) def logical_not(x): - return np.logical_not(x) + raise NotImplementedError( + "`logical_not` is not supported with openvino backend" + ) def logical_or(x1, x2): - return np.logical_or(x1, x2) + raise NotImplementedError( + "`logical_or` is not supported with openvino backend" + ) def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): - if dtype is None: - dtypes_to_resolve = [ - getattr(start, "dtype", type(start)), - getattr(stop, "dtype", type(stop)), - float, - ] - dtype = dtypes.result_type(*dtypes_to_resolve) - return np.logspace( - start, - stop, - num=num, - endpoint=endpoint, - base=base, - dtype=dtype, - axis=axis, + raise NotImplementedError( + "`logspace` is not supported with openvino backend" ) def maximum(x1, x2): - if not isinstance(x1, (int, float)): - x1 = convert_to_tensor(x1) - if not isinstance(x2, (int, float)): - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type( - getattr(x1, "dtype", type(x1)), - getattr(x2, "dtype", type(x2)), + raise NotImplementedError( + "`maximum` is not supported with openvino backend" ) - x1 = convert_to_tensor(x1, dtype) - x2 = convert_to_tensor(x2, dtype) - return np.maximum(x1, x2) def median(x, axis=None, keepdims=False): - dtype = dtypes.result_type(x.dtype, float) - return np.median(x, axis=axis, keepdims=keepdims).astype(dtype) + raise NotImplementedError( + "`median` is not supported with openvino backend" + ) def meshgrid(*x, indexing="xy"): - return np.meshgrid(*x, indexing=indexing) + raise NotImplementedError( + "`meshgrid` is not supported with openvino backend" + ) def min(x, axis=None, keepdims=False, initial=None): - axis = standardize_axis_for_numpy(axis) - return np.min(x, axis=axis, keepdims=keepdims, initial=initial) + raise NotImplementedError( + "`min` is not supported with openvino backend" + ) def minimum(x1, x2): - if not isinstance(x1, (int, float)): - x1 = convert_to_tensor(x1) - if not isinstance(x2, (int, float)): - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type( - getattr(x1, "dtype", type(x1)), - getattr(x2, "dtype", type(x2)), + raise NotImplementedError( + "`minimum` is not supported with openvino backend" ) - x1 = convert_to_tensor(x1, dtype) - x2 = convert_to_tensor(x2, dtype) - return np.minimum(x1, x2) def mod(x1, x2): - x1 = convert_to_tensor(x1) - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type(x1.dtype, x2.dtype) - if dtype == "bool": - dtype = "int32" - x1 = x1.astype(dtype) - x2 = x2.astype(dtype) - return np.mod(x1, x2) + raise NotImplementedError( + "`mod` is not supported with openvino backend" + ) def moveaxis(x, source, destination): - return np.moveaxis(x, source=source, destination=destination) + raise NotImplementedError( 
+ "`moveaxis` is not supported with openvino backend" + ) def nan_to_num(x, nan=0.0, posinf=None, neginf=None): - return np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + raise NotImplementedError( + "`nan_to_num` is not supported with openvino backend" + ) def ndim(x): - return np.ndim(x) + raise NotImplementedError( + "`ndim` is not supported with openvino backend" + ) def nonzero(x): - return tuple(indices.astype("int32") for indices in np.nonzero(x)) + raise NotImplementedError( + "`nonzero` is not supported with openvino backend" + ) def not_equal(x1, x2): - return np.not_equal(x1, x2) + raise NotImplementedError( + "`not_equal` is not supported with openvino backend" + ) def zeros_like(x, dtype=None): - return np.zeros_like(x, dtype=dtype) + raise NotImplementedError( + "`zeros_like` is not supported with openvino backend" + ) def ones_like(x, dtype=None): - return np.ones_like(x, dtype=dtype) + raise NotImplementedError( + "`ones_like` is not supported with openvino backend" + ) def outer(x1, x2): - x1 = convert_to_tensor(x1) - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type(x1.dtype, x2.dtype) - x1 = x1.astype(dtype) - x2 = x2.astype(dtype) - return np.outer(x1, x2) + raise NotImplementedError( + "`outer` is not supported with openvino backend" + ) def pad(x, pad_width, mode="constant", constant_values=None): - kwargs = {} - if constant_values is not None: - if mode != "constant": - raise ValueError( - "Argument `constant_values` can only be " - "provided when `mode == 'constant'`. " - f"Received: mode={mode}" - ) - kwargs["constant_values"] = constant_values - return np.pad(x, pad_width, mode=mode, **kwargs) + raise NotImplementedError( + "`pad` is not supported with openvino backend" + ) def prod(x, axis=None, keepdims=False, dtype=None): - axis = standardize_axis_for_numpy(axis) - x = convert_to_tensor(x) - if dtype is None: - dtype = dtypes.result_type(x.dtype) - if dtype in ("bool", "int8", "int16"): - dtype = "int32" - elif dtype in ("uint8", "uint16"): - dtype = "uint32" - return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype) + raise NotImplementedError( + "`prod` is not supported with openvino backend" + ) def quantile(x, q, axis=None, method="linear", keepdims=False): - axis = standardize_axis_for_numpy(axis) - x = convert_to_tensor(x) - - ori_dtype = standardize_dtype(x.dtype) - # np.quantile doesn't support bool - if ori_dtype == "bool": - x = x.astype(config.floatx()) - if ori_dtype == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - return np.quantile( - x, q, axis=axis, method=method, keepdims=keepdims - ).astype(dtype) + raise NotImplementedError( + "`quantile` is not supported with openvino backend" + ) def ravel(x): - return np.ravel(x) + raise NotImplementedError( + "`ravel` is not supported with openvino backend" + ) def real(x): - return np.real(x) + raise NotImplementedError( + "`real` is not supported with openvino backend" + ) def reciprocal(x): - return np.reciprocal(x) + raise NotImplementedError( + "`reciprocal` is not supported with openvino backend" + ) def repeat(x, repeats, axis=None): - return np.repeat(x, repeats, axis=axis) + raise NotImplementedError( + "`repeat` is not supported with openvino backend" + ) def reshape(x, newshape): - return np.reshape(x, newshape) + raise NotImplementedError( + "`reshape` is not supported with openvino backend" + ) def roll(x, shift, axis=None): - return np.roll(x, shift, axis=axis) + raise NotImplementedError( + "`roll` is not supported with openvino 
backend" + ) def sign(x): - return np.sign(x) + raise NotImplementedError( + "`sign` is not supported with openvino backend" + ) def sin(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.sin(x) + raise NotImplementedError( + "`sin` is not supported with openvino backend" + ) def sinh(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.sinh(x) + raise NotImplementedError( + "`sinh` is not supported with openvino backend" + ) def size(x): - return np.size(x) + raise NotImplementedError( + "`size` is not supported with openvino backend" + ) def sort(x, axis=-1): - axis = standardize_axis_for_numpy(axis) - return np.sort(x, axis=axis) + raise NotImplementedError( + "`sort` is not supported with openvino backend" + ) def split(x, indices_or_sections, axis=0): - axis = standardize_axis_for_numpy(axis) - return np.split(x, indices_or_sections, axis=axis) + raise NotImplementedError( + "`split` is not supported with openvino backend" + ) def stack(x, axis=0): - axis = standardize_axis_for_numpy(axis) - dtype_set = set([getattr(a, "dtype", type(a)) for a in x]) - if len(dtype_set) > 1: - dtype = dtypes.result_type(*dtype_set) - x = tree.map_structure(lambda a: convert_to_tensor(a).astype(dtype), x) - return np.stack(x, axis=axis) + raise NotImplementedError( + "`stack` is not supported with openvino backend" + ) def std(x, axis=None, keepdims=False): - axis = standardize_axis_for_numpy(axis) - x = convert_to_tensor(x) - ori_dtype = standardize_dtype(x.dtype) - if "int" in ori_dtype or ori_dtype == "bool": - x = x.astype(config.floatx()) - return np.std(x, axis=axis, keepdims=keepdims) + raise NotImplementedError( + "`std` is not supported with openvino backend" + ) def swapaxes(x, axis1, axis2): - return np.swapaxes(x, axis1=axis1, axis2=axis2) + raise NotImplementedError( + "`swapaxes` is not supported with openvino backend" + ) def take(x, indices, axis=None): - axis = standardize_axis_for_numpy(axis) - return np.take(x, indices, axis=axis) + raise NotImplementedError( + "`take` is not supported with openvino backend" + ) def take_along_axis(x, indices, axis=None): - axis = standardize_axis_for_numpy(axis) - return np.take_along_axis(x, indices, axis=axis) + raise NotImplementedError( + "`take_along_axis` is not supported with openvino backend" + ) def tan(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.tan(x) + raise NotImplementedError( + "`tan` is not supported with openvino backend" + ) def tanh(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "int64": - dtype = config.floatx() - else: - dtype = dtypes.result_type(x.dtype, float) - x = x.astype(dtype) - return np.tanh(x) + raise NotImplementedError( + "`tanh` is not supported with openvino backend" + ) def tensordot(x1, x2, axes=2): - axes = tuple(axes) if isinstance(axes, list) else axes - x1 = convert_to_tensor(x1) - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type(x1.dtype, x2.dtype) - x1 = x1.astype(dtype) - x2 = x2.astype(dtype) - return np.tensordot(x1, x2, axes=axes) + raise NotImplementedError( + "`tensordot` is not supported with openvino backend" + ) def round(x, decimals=0): - return np.round(x, 
decimals=decimals) + raise NotImplementedError( + "`round` is not supported with openvino backend" + ) def tile(x, repeats): - return np.tile(x, repeats) + raise NotImplementedError( + "`tile` is not supported with openvino backend" + ) def trace(x, offset=0, axis1=0, axis2=1): - axis1 = standardize_axis_for_numpy(axis1) - axis2 = standardize_axis_for_numpy(axis2) - x = convert_to_tensor(x) - dtype = standardize_dtype(x.dtype) - if dtype not in ("int64", "uint32", "uint64"): - dtype = dtypes.result_type(dtype, "int32") - return np.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) + raise NotImplementedError( + "`trace` is not supported with openvino backend" + ) def tri(N, M=None, k=0, dtype=None): - dtype = dtype or config.floatx() - return np.tri(N, M=M, k=k, dtype=dtype) + raise NotImplementedError( + "`tri` is not supported with openvino backend" + ) def tril(x, k=0): - return np.tril(x, k=k) + raise NotImplementedError( + "`tril` is not supported with openvino backend" + ) def triu(x, k=0): - return np.triu(x, k=k) + raise NotImplementedError( + "`triu` is not supported with openvino backend" + ) def vdot(x1, x2): - x1 = convert_to_tensor(x1) - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type(x1.dtype, x2.dtype) - x1 = x1.astype(dtype) - x2 = x2.astype(dtype) - return np.vdot(x1, x2) + raise NotImplementedError( + "`vdot` is not supported with openvino backend" + ) def vstack(xs): - dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) - if len(dtype_set) > 1: - dtype = dtypes.result_type(*dtype_set) - xs = tree.map_structure( - lambda x: convert_to_tensor(x).astype(dtype), xs - ) - return np.vstack(xs) + raise NotImplementedError( + "`vstack` is not supported with openvino backend" + ) def vectorize(pyfunc, *, excluded=None, signature=None): - return np.vectorize(pyfunc, excluded=excluded, signature=signature) + raise NotImplementedError( + "`vectorize` is not supported with openvino backend" + ) def where(condition, x1, x2): - if x1 is not None and x2 is not None: - if not isinstance(x1, (int, float)): - x1 = convert_to_tensor(x1) - if not isinstance(x2, (int, float)): - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type( - getattr(x1, "dtype", type(x1)), - getattr(x2, "dtype", type(x2)), - ) - x1 = convert_to_tensor(x1, dtype) - x2 = convert_to_tensor(x2, dtype) - return np.where(condition, x1, x2) - else: - return np.where(condition) + raise NotImplementedError( + "`where` is not supported with openvino backend" + ) def divide(x1, x2): @@ -962,18 +776,9 @@ def divide(x1, x2): def divide_no_nan(x1, x2): - if not isinstance(x1, (int, float)): - x1 = convert_to_tensor(x1) - if not isinstance(x2, (int, float)): - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type( - getattr(x1, "dtype", type(x1)), - getattr(x2, "dtype", type(x2)), - float, + raise NotImplementedError( + "`divide_no_nan` is not supported with openvino backend" ) - x1 = convert_to_tensor(x1, dtype) - x2 = convert_to_tensor(x2, dtype) - return np.where(x2 == 0, 0, np.divide(x1, x2)) def true_divide(x1, x2): @@ -992,70 +797,51 @@ def negative(x): def square(x): - x = convert_to_tensor(x) - if standardize_dtype(x.dtype) == "bool": - x = x.astype("int32") - return np.square(x) + raise NotImplementedError( + "`square` is not supported with openvino backend" + ) def sqrt(x): - x = convert_to_tensor(x) - # upcast to float64 for int64 which matches JAX's behavior - dtype = ( - config.floatx() - if standardize_dtype(x.dtype) == "int64" - else dtypes.result_type(x.dtype, float) + raise 
NotImplementedError( + "`sqrt` is not supported with openvino backend" ) - return np.sqrt(x, dtype=dtype) def squeeze(x, axis=None): - axis = standardize_axis_for_numpy(axis) - return np.squeeze(x, axis=axis) + raise NotImplementedError( + "`squeeze` is not supported with openvino backend" + ) def transpose(x, axes=None): - axes = tuple(axes) if isinstance(axes, list) else axes - return np.transpose(x, axes=axes) + raise NotImplementedError( + "`transpose` is not supported with openvino backend" + ) def var(x, axis=None, keepdims=False): - axis = standardize_axis_for_numpy(axis) - x = convert_to_tensor(x) - compute_dtype = dtypes.result_type(x.dtype, "float32") - result_dtype = dtypes.result_type(x.dtype, float) - return np.var(x, axis=axis, keepdims=keepdims, dtype=compute_dtype).astype( - result_dtype + raise NotImplementedError( + "`var` is not supported with openvino backend" ) def sum(x, axis=None, keepdims=False): - axis = standardize_axis_for_numpy(axis) - dtype = standardize_dtype(x.dtype) - # follow jax's rule - if dtype in ("bool", "int8", "int16"): - dtype = "int32" - elif dtype in ("uint8", "uint16"): - dtype = "uint32" - return np.sum(x, axis=axis, keepdims=keepdims).astype(dtype) + raise NotImplementedError( + "`sum` is not supported with openvino backend" + ) def eye(N, M=None, k=0, dtype=None): - dtype = dtype or config.floatx() - return np.eye(N, M=M, k=k, dtype=dtype) + raise NotImplementedError( + "`eye` is not supported with openvino backend" + ) def floor_divide(x1, x2): - if not isinstance(x1, (int, float)): - x1 = convert_to_tensor(x1) - if not isinstance(x2, (int, float)): - x2 = convert_to_tensor(x2) - dtype = dtypes.result_type( - getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)) - ) - x1 = convert_to_tensor(x1, dtype) - x2 = convert_to_tensor(x2, dtype) - return np.floor_divide(x1, x2) + raise NotImplementedError( + "`floor_divide` is not supported with openvino backend" + ) def logical_xor(x1, x2): From c2c18082ce46973f4ce74a6b6f2e4ecc314d0c6b Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Mon, 23 Sep 2024 21:37:38 +0400 Subject: [PATCH 03/19] Mark unsupported ops in core, image, and linalg spaces Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/core.py | 2 +- keras/src/backend/openvino/image.py | 394 ++------------------------- keras/src/backend/openvino/layer.py | 2 +- keras/src/backend/openvino/linalg.py | 80 +++--- 4 files changed, 60 insertions(+), 418 deletions(-) diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index e1d249508a3..ea4cec6001a 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -254,5 +254,5 @@ def unstack(x, num=None, axis=0): def custom_gradient(fun): raise NotImplementedError( - "`custom_gradient` is not supported with numpy backend" + "`custom_gradient` is not supported with openvino backend" ) diff --git a/keras/src/backend/openvino/image.py b/keras/src/backend/openvino/image.py index d385da3d834..86c59e51231 100644 --- a/keras/src/backend/openvino/image.py +++ b/keras/src/backend/openvino/image.py @@ -1,385 +1,41 @@ -import jax -import numpy as np - -from keras.src.backend.openvino.core import convert_to_tensor -from keras.src.utils.module_utils import scipy - -RESIZE_INTERPOLATIONS = ( - "bilinear", - "nearest", - "lanczos3", - "lanczos5", - "bicubic", -) - - def rgb_to_grayscale(image, data_format="channels_last"): - if data_format == "channels_first": - if len(image.shape) == 4: - image = np.transpose(image, (0, 2, 3, 1)) 
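# A hypothetical sketch, not part of this patch series: the numpy grayscale
# math being removed here could later be rebuilt as OpenVINO graph nodes, in
# the same opset14 style this series uses for `log_sigmoid`. The helper name
# and keyword arguments below are assumptions; `image` is an OpenVINO node.
import numpy as np
from openvino.runtime.opset14 import constant, multiply, reduce_sum

def _rgb_to_grayscale_sketch(image):
    # Standard ITU-R BT.601 luma weights, broadcast over the channel axis
    # (assumes channels_last layout).
    weights = constant(np.array([0.2989, 0.5870, 0.1140], dtype=np.float32))
    weighted = multiply(image, weights)
    # Sum over the channel axis, keeping it as a size-1 dimension.
    return reduce_sum(weighted, reduction_axes=-1, keep_dims=True)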
- elif len(image.shape) == 3: - image = np.transpose(image, (1, 2, 0)) - else: - raise ValueError( - "Invalid input rank: expected rank 3 (single image) " - "or rank 4 (batch of images). Received input with shape: " - f"image.shape={image.shape}" - ) - red, green, blue = image[..., 0], image[..., 1], image[..., 2] - grayscale_image = 0.2989 * red + 0.5870 * green + 0.1140 * blue - grayscale_image = np.expand_dims(grayscale_image, axis=-1) - if data_format == "channels_first": - if len(image.shape) == 4: - grayscale_image = np.transpose(grayscale_image, (0, 3, 1, 2)) - elif len(image.shape) == 3: - grayscale_image = np.transpose(grayscale_image, (2, 0, 1)) - return np.array(grayscale_image) + raise NotImplementedError( + "`rgb_to_grayscale` is not supported with openvino backend" + ) def resize( - image, - size, - interpolation="bilinear", - antialias=False, - crop_to_aspect_ratio=False, - pad_to_aspect_ratio=False, - fill_mode="constant", - fill_value=0.0, - data_format="channels_last", + image, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format="channels_last", ): - if interpolation not in RESIZE_INTERPOLATIONS: - raise ValueError( - "Invalid value for argument `interpolation`. Expected of one " - f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}" - ) - if fill_mode != "constant": - raise ValueError( - "Invalid value for argument `fill_mode`. Only `'constant'` " - f"is supported. Received: fill_mode={fill_mode}" - ) - if pad_to_aspect_ratio and crop_to_aspect_ratio: - raise ValueError( - "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` " - "can be `True`." - ) - if not len(size) == 2: - raise ValueError( - "Argument `size` must be a tuple of two elements " - f"(height, width). Received: size={size}" - ) - size = tuple(size) - target_height, target_width = size - if len(image.shape) == 4: - if data_format == "channels_last": - size = (image.shape[0],) + size + (image.shape[-1],) - else: - size = (image.shape[0], image.shape[1]) + size - elif len(image.shape) == 3: - if data_format == "channels_last": - size = size + (image.shape[-1],) - else: - size = (image.shape[0],) + size - else: - raise ValueError( - "Invalid input rank: expected rank 3 (single image) " - "or rank 4 (batch of images). 
Received input with shape: " - f"image.shape={image.shape}" - ) - - if crop_to_aspect_ratio: - shape = image.shape - if data_format == "channels_last": - height, width = shape[-3], shape[-2] - else: - height, width = shape[-2], shape[-1] - crop_height = int(float(width * target_height) / target_width) - crop_height = min(height, crop_height) - crop_width = int(float(height * target_width) / target_height) - crop_width = min(width, crop_width) - crop_box_hstart = int(float(height - crop_height) / 2) - crop_box_wstart = int(float(width - crop_width) / 2) - if data_format == "channels_last": - if len(image.shape) == 4: - image = image[ - :, - crop_box_hstart : crop_box_hstart + crop_height, - crop_box_wstart : crop_box_wstart + crop_width, - :, - ] - else: - image = image[ - crop_box_hstart : crop_box_hstart + crop_height, - crop_box_wstart : crop_box_wstart + crop_width, - :, - ] - else: - if len(image.shape) == 4: - image = image[ - :, - :, - crop_box_hstart : crop_box_hstart + crop_height, - crop_box_wstart : crop_box_wstart + crop_width, - ] - else: - image = image[ - :, - crop_box_hstart : crop_box_hstart + crop_height, - crop_box_wstart : crop_box_wstart + crop_width, - ] - elif pad_to_aspect_ratio: - shape = image.shape - batch_size = image.shape[0] - if data_format == "channels_last": - height, width, channels = shape[-3], shape[-2], shape[-1] - else: - channels, height, width = shape[-3], shape[-2], shape[-1] - pad_height = int(float(width * target_height) / target_width) - pad_height = max(height, pad_height) - pad_width = int(float(height * target_width) / target_height) - pad_width = max(width, pad_width) - img_box_hstart = int(float(pad_height - height) / 2) - img_box_wstart = int(float(pad_width - width) / 2) - if data_format == "channels_last": - if len(image.shape) == 4: - padded_img = ( - np.ones( - ( - batch_size, - pad_height + height, - pad_width + width, - channels, - ), - dtype=image.dtype, - ) - * fill_value - ) - padded_img[ - :, - img_box_hstart : img_box_hstart + height, - img_box_wstart : img_box_wstart + width, - :, - ] = image - else: - padded_img = ( - np.ones( - (pad_height + height, pad_width + width, channels), - dtype=image.dtype, - ) - * fill_value - ) - padded_img[ - img_box_hstart : img_box_hstart + height, - img_box_wstart : img_box_wstart + width, - :, - ] = image - else: - if len(image.shape) == 4: - padded_img = ( - np.ones( - ( - batch_size, - channels, - pad_height + height, - pad_width + width, - ), - dtype=image.dtype, - ) - * fill_value - ) - padded_img[ - :, - :, - img_box_hstart : img_box_hstart + height, - img_box_wstart : img_box_wstart + width, - ] = image - else: - padded_img = ( - np.ones( - (channels, pad_height + height, pad_width + width), - dtype=image.dtype, - ) - * fill_value - ) - padded_img[ - :, - img_box_hstart : img_box_hstart + height, - img_box_wstart : img_box_wstart + width, - ] = image - image = padded_img - - return np.array( - jax.image.resize(image, size, method=interpolation, antialias=antialias) + raise NotImplementedError( + "`resize` is not supported with openvino backend" ) -AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order - "nearest": 0, - "bilinear": 1, -} -AFFINE_TRANSFORM_FILL_MODES = { - "constant", - "nearest", - "wrap", - "mirror", - "reflect", -} - - def affine_transform( - image, - transform, - interpolation="bilinear", - fill_mode="constant", - fill_value=0, - data_format="channels_last", + image, + transform, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + 
data_format="channels_last", ): - if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): - raise ValueError( - "Invalid value for argument `interpolation`. Expected of one " - f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " - f"interpolation={interpolation}" - ) - if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: - raise ValueError( - "Invalid value for argument `fill_mode`. Expected of one " - f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" - ) - - transform = convert_to_tensor(transform) - - if len(image.shape) not in (3, 4): - raise ValueError( - "Invalid image rank: expected rank 3 (single image) " - "or rank 4 (batch of images). Received input with shape: " - f"image.shape={image.shape}" - ) - if len(transform.shape) not in (1, 2): - raise ValueError( - "Invalid transform rank: expected rank 1 (single transform) " - "or rank 2 (batch of transforms). Received input with shape: " - f"transform.shape={transform.shape}" - ) - - # scipy.ndimage.map_coordinates lacks support for half precision. - input_dtype = image.dtype - if input_dtype == "float16": - image = image.astype("float32") - - # unbatched case - need_squeeze = False - if len(image.shape) == 3: - image = np.expand_dims(image, axis=0) - need_squeeze = True - if len(transform.shape) == 1: - transform = np.expand_dims(transform, axis=0) - - if data_format == "channels_first": - image = np.transpose(image, (0, 2, 3, 1)) - - batch_size = image.shape[0] - - # get indices - meshgrid = np.meshgrid( - *[np.arange(size) for size in image.shape[1:]], indexing="ij" - ) - indices = np.concatenate( - [np.expand_dims(x, axis=-1) for x in meshgrid], axis=-1 - ) - indices = np.tile(indices, (batch_size, 1, 1, 1, 1)) - - # swap the values - a0 = transform[:, 0].copy() - a2 = transform[:, 2].copy() - b1 = transform[:, 4].copy() - b2 = transform[:, 5].copy() - transform[:, 0] = b1 - transform[:, 2] = b2 - transform[:, 4] = a0 - transform[:, 5] = a2 - - # deal with transform - transform = np.pad(transform, pad_width=[[0, 0], [0, 1]], constant_values=1) - transform = np.reshape(transform, (batch_size, 3, 3)) - offset = transform[:, 0:2, 2].copy() - offset = np.pad(offset, pad_width=[[0, 0], [0, 1]]) - transform[:, 0:2, 2] = 0 - - # transform the indices - coordinates = np.einsum("Bhwij, Bjk -> Bhwik", indices, transform) - coordinates = np.moveaxis(coordinates, source=-1, destination=1) - coordinates += np.reshape(a=offset, newshape=(*offset.shape, 1, 1, 1)) - - # apply affine transformation - affined = np.stack( - [ - map_coordinates( - image[i], - coordinates[i], - order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], - fill_mode=fill_mode, - fill_value=fill_value, - ) - for i in range(batch_size) - ], - axis=0, + raise NotImplementedError( + "`affine_transform` is not supported with openvino backend" ) - if data_format == "channels_first": - affined = np.transpose(affined, (0, 3, 1, 2)) - if need_squeeze: - affined = np.squeeze(affined, axis=0) - if input_dtype == "float16": - affined = affined.astype(input_dtype) - return affined - - -MAP_COORDINATES_FILL_MODES = { - "constant", - "nearest", - "wrap", - "mirror", - "reflect", -} - def map_coordinates( - input, coordinates, order, fill_mode="constant", fill_value=0.0 + input, coordinates, order, fill_mode="constant", fill_value=0.0 ): - if fill_mode not in MAP_COORDINATES_FILL_MODES: - raise ValueError( - "Invalid value for argument `fill_mode`. Expected one of " - f"{set(MAP_COORDINATES_FILL_MODES.keys())}. 
Received: " - f"fill_mode={fill_mode}" - ) - if order not in range(2): - raise ValueError( - "Invalid value for argument `order`. Expected one of " - f"{[0, 1]}. Received: order={order}" - ) - # SciPy's implementation of map_coordinates handles boundaries incorrectly, - # unless mode='reflect'. For order=1, this only affects interpolation - # outside the bounds of the original array. - # https://github.com/scipy/scipy/issues/2640 - padding = [ - ( - max(-np.floor(c.min()).astype(int) + 1, 0), - max(np.ceil(c.max()).astype(int) + 1 - size, 0), - ) - for c, size in zip(coordinates, input.shape) - ] - shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)] - pad_mode = { - "nearest": "edge", - "mirror": "reflect", - "reflect": "symmetric", - }.get(fill_mode, fill_mode) - if fill_mode == "constant": - padded = np.pad( - input, padding, mode=pad_mode, constant_values=fill_value - ) - else: - padded = np.pad(input, padding, mode=pad_mode) - result = scipy.ndimage.map_coordinates( - padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value + raise NotImplementedError( + "`map_coordinates` is not supported with openvino backend" ) - return result diff --git a/keras/src/backend/openvino/layer.py b/keras/src/backend/openvino/layer.py index 08b761f972e..334c32958a7 100644 --- a/keras/src/backend/openvino/layer.py +++ b/keras/src/backend/openvino/layer.py @@ -1,2 +1,2 @@ -class NumpyLayer: +class OpenvinoLayer: pass diff --git a/keras/src/backend/openvino/linalg.py b/keras/src/backend/openvino/linalg.py index c925f3fbee0..9675e9fa4be 100644 --- a/keras/src/backend/openvino/linalg.py +++ b/keras/src/backend/openvino/linalg.py @@ -1,88 +1,74 @@ import numpy as np -import scipy.linalg as sl - -from keras.src.backend import standardize_dtype -from keras.src.backend.common import dtypes from keras.src.backend.openvino.core import convert_to_tensor def cholesky(a): - return np.linalg.cholesky(a) + raise NotImplementedError( + "`cholesky` is not supported with openvino backend" + ) def det(a): - return np.linalg.det(a) + raise NotImplementedError( + "`det` is not supported with openvino backend" + ) def eig(a): - return np.linalg.eig(a) + raise NotImplementedError( + "`eig` is not supported with openvino backend" + ) def eigh(a): - return np.linalg.eigh(a) + raise NotImplementedError( + "`eigh` is not supported with openvino backend" + ) def inv(a): - return np.linalg.inv(a) + raise NotImplementedError( + "`inv` is not supported with openvino backend" + ) def lu_factor(a): - if a.ndim == 2: - return sl.lu_factor(a) - - m, n = a.shape[-2:] - signature = "(m,n) -> (m,n), " - signature += "(m)" if m <= n else "(n)" - _lu_factor_gufunc = np.vectorize( - sl.lu_factor, - signature=signature, + raise NotImplementedError( + "`lu_factor` is not supported with openvino backend" ) - return _lu_factor_gufunc(a) def norm(x, ord=None, axis=None, keepdims=False): - x = convert_to_tensor(x) - dtype = standardize_dtype(x.dtype) - if "int" in dtype or dtype == "bool": - dtype = dtypes.result_type(x.dtype, "float32") - return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims).astype( - dtype + raise NotImplementedError( + "`norm` is not supported with openvino backend" ) def qr(x, mode="reduced"): - if mode not in {"reduced", "complete"}: - raise ValueError( - "`mode` argument value not supported. " - "Expected one of {'reduced', 'complete'}. 
" - f"Received: mode={mode}" - ) - return np.linalg.qr(x, mode=mode) + raise NotImplementedError( + "`qr` is not supported with openvino backend" + ) def solve(a, b): - return np.linalg.solve(a, b) + raise NotImplementedError( + "`solve` is not supported with openvino backend" + ) def solve_triangular(a, b, lower=False): - if a.ndim == 2: - return sl.solve_triangular(a, b, lower=lower) - - _vectorized_solve_triangular = np.vectorize( - lambda a, b: sl.solve_triangular(a, b, lower=lower), - signature="(n,n),(n,m)->(n,m)", + raise NotImplementedError( + "`solve_triangular` is not supported with openvino backend" ) - if b.ndim == a.ndim - 1: - b = np.expand_dims(b, axis=-1) - return _vectorized_solve_triangular(a, b).squeeze(axis=-1) - return _vectorized_solve_triangular(a, b) def svd(x, full_matrices=True, compute_uv=True): - return np.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) + raise NotImplementedError( + "`svd` is not supported with openvino backend" + ) def lstsq(a, b, rcond=None): - a = convert_to_tensor(a) - b = convert_to_tensor(b) - return np.linalg.lstsq(a, b, rcond=rcond)[0] + raise NotImplementedError( + "`lstsq` is not supported with openvino backend" + ) From d26c22c3341a8c8e0299383abc92c2744337a369 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Mon, 23 Sep 2024 22:04:09 +0400 Subject: [PATCH 04/19] Mark unsupported ops in math, nn, random, and rnn spaces Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/math.py | 327 ++--------- keras/src/backend/openvino/nn.py | 811 ++------------------------- keras/src/backend/openvino/random.py | 120 +--- keras/src/backend/openvino/rnn.py | 247 +------- 4 files changed, 158 insertions(+), 1347 deletions(-) diff --git a/keras/src/backend/openvino/math.py b/keras/src/backend/openvino/math.py index 81e5d71639d..23cf17bb3a6 100644 --- a/keras/src/backend/openvino/math.py +++ b/keras/src/backend/openvino/math.py @@ -1,329 +1,116 @@ -import numpy as np - -from keras.src.backend import standardize_dtype -from keras.src.backend.common import dtypes -from keras.src.backend.jax.math import fft as jax_fft -from keras.src.backend.jax.math import fft2 as jax_fft2 -from keras.src.backend.openvino.core import convert_to_tensor -from keras.src.utils.module_utils import scipy - - def segment_sum(data, segment_ids, num_segments=None, sorted=False): - if num_segments is None: - num_segments = np.amax(segment_ids) + 1 - - valid_indices = segment_ids >= 0 # Ignore segment_ids that are -1 - valid_data = data[valid_indices] - valid_segment_ids = segment_ids[valid_indices] - - data_shape = list(valid_data.shape) - data_shape[0] = ( - num_segments # Replace first dimension (which corresponds to segments) + raise NotImplementedError( + "`segment_sum` is not supported with openvino backend" ) - if sorted: - result = np.zeros(data_shape, dtype=valid_data.dtype) - np.add.at(result, valid_segment_ids, valid_data) - else: - sort_indices = np.argsort(valid_segment_ids) - sorted_segment_ids = valid_segment_ids[sort_indices] - sorted_data = valid_data[sort_indices] - - result = np.zeros(data_shape, dtype=valid_data.dtype) - np.add.at(result, sorted_segment_ids, sorted_data) - - return result - def segment_max(data, segment_ids, num_segments=None, sorted=False): - if num_segments is None: - num_segments = np.amax(segment_ids) + 1 - - valid_indices = segment_ids >= 0 # Ignore segment_ids that are -1 - valid_data = data[valid_indices] - valid_segment_ids = segment_ids[valid_indices] - - data_shape = list(valid_data.shape) - 
data_shape[0] = ( - num_segments # Replace first dimension (which corresponds to segments) + raise NotImplementedError( + "`segment_max` is not supported with openvino backend" ) - if sorted: - result = np.zeros(data_shape, dtype=valid_data.dtype) - np.maximum.at(result, valid_segment_ids, valid_data) - else: - sort_indices = np.argsort(valid_segment_ids) - sorted_segment_ids = valid_segment_ids[sort_indices] - sorted_data = valid_data[sort_indices] - - result = np.zeros(data_shape, dtype=valid_data.dtype) - np.maximum.at(result, sorted_segment_ids, sorted_data) - - return result - def top_k(x, k, sorted=False): - sorted_indices = np.argsort(x, axis=-1)[..., ::-1] - sorted_values = np.sort(x, axis=-1)[..., ::-1] - - if sorted: - # Take the k largest values. - top_k_values = sorted_values[..., :k] - top_k_indices = sorted_indices[..., :k] - else: - # Partition the array such that all values larger than the k-th - # largest value are to the right of it. - top_k_values = np.partition(x, -k, axis=-1)[..., -k:] - top_k_indices = np.argpartition(x, -k, axis=-1)[..., -k:] - - # Get the indices in sorted order. - idx = np.argsort(-top_k_values, axis=-1) - - # Get the top k values and their indices. - top_k_values = np.take_along_axis(top_k_values, idx, axis=-1) - top_k_indices = np.take_along_axis(top_k_indices, idx, axis=-1) - - return top_k_values, top_k_indices + raise NotImplementedError( + "`top_k` is not supported with openvino backend" + ) def in_top_k(targets, predictions, k): - targets = targets[:, None] - topk_values = top_k(predictions, k)[0] - targets_values = np.take_along_axis(predictions, targets, axis=-1) - mask = targets_values >= topk_values - return np.any(mask, axis=-1) + raise NotImplementedError( + "`in_top_k` is not supported with openvino backend" + ) def logsumexp(x, axis=None, keepdims=False): - max_x = np.max(x, axis=axis, keepdims=True) - result = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + max_x - return np.squeeze(result) if not keepdims else result + raise NotImplementedError( + "`logsumexp` is not supported with openvino backend" + ) def qr(x, mode="reduced"): - if mode not in {"reduced", "complete"}: - raise ValueError( - "`mode` argument value not supported. " - "Expected one of {'reduced', 'complete'}. " - f"Received: mode={mode}" - ) - return np.linalg.qr(x, mode=mode) + raise NotImplementedError( + "`qr` is not supported with openvino backend" + ) def extract_sequences(x, sequence_length, sequence_stride): - *batch_shape, _ = x.shape - batch_shape = list(batch_shape) - shape = x.shape[:-1] + ( - (x.shape[-1] - (sequence_length - sequence_stride)) // sequence_stride, - sequence_length, + raise NotImplementedError( + "`extract_sequences` is not supported with openvino backend" ) - strides = x.strides[:-1] + ( - sequence_stride * x.strides[-1], - x.strides[-1], - ) - x = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) - return np.reshape(x, (*batch_shape, *x.shape[-2:])) - - -def _get_complex_tensor_from_tuple(x): - if not isinstance(x, (tuple, list)) or len(x) != 2: - raise ValueError( - "Input `x` should be a tuple of two tensors - real and imaginary." - f"Received: x={x}" - ) - # `convert_to_tensor` does not support passing complex tensors. We separate - # the input out into real and imaginary and convert them separately. - real, imag = x - # Check shapes. - if real.shape != imag.shape: - raise ValueError( - "Input `x` should be a tuple of two tensors - real and imaginary." 
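# A hypothetical sketch, not part of this patch series: `top_k`, stubbed out
# above, maps naturally onto OpenVINO's TopK operation. The opset14 helper
# and its argument names are assumptions to check against the installed
# OpenVINO version; the wrapper name is invented for illustration.
from openvino.runtime.opset14 import topk

def _top_k_sketch(x, k, sorted=False):
    # TopK yields two outputs: the values and their indices along axis -1.
    node = topk(x, k, axis=-1, mode="max", sort="value" if sorted else "none")
    return node.output(0), node.output(1)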
- "Both the real and imaginary parts should have the same shape. " - f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" - ) - # Ensure dtype is float. - if not np.issubdtype(real.dtype, np.floating) or not np.issubdtype( - imag.dtype, np.floating - ): - raise ValueError( - "At least one tensor in input `x` is not of type float." - f"Received: x={x}." - ) - complex_input = real + 1j * imag - return complex_input def fft(x): - real, imag = jax_fft(x) - return np.array(real), np.array(imag) + raise NotImplementedError( + "`fft` is not supported with openvino backend" + ) def fft2(x): - real, imag = jax_fft2(x) - return np.array(real), np.array(imag) + raise NotImplementedError( + "`fft2` is not supported with openvino backend" + ) def rfft(x, fft_length=None): - complex_output = np.fft.rfft(x, n=fft_length, axis=-1, norm="backward") - # numpy always outputs complex128, so we need to recast the dtype - return ( - np.real(complex_output).astype(x.dtype), - np.imag(complex_output).astype(x.dtype), + raise NotImplementedError( + "`rfft` is not supported with openvino backend" ) def irfft(x, fft_length=None): - complex_input = _get_complex_tensor_from_tuple(x) - # numpy always outputs float64, so we need to recast the dtype - return np.fft.irfft( - complex_input, n=fft_length, axis=-1, norm="backward" - ).astype(x[0].dtype) + raise NotImplementedError( + "`irfft` is not supported with openvino backend" + ) def stft( - x, sequence_length, sequence_stride, fft_length, window="hann", center=True + x, sequence_length, sequence_stride, fft_length, window="hann", center=True ): - if standardize_dtype(x.dtype) not in {"float32", "float64"}: - raise TypeError( - "Invalid input type. Expected `float32` or `float64`. " - f"Received: input type={x.dtype}" - ) - if fft_length < sequence_length: - raise ValueError( - "`fft_length` must equal or larger than `sequence_length`. " - f"Received: sequence_length={sequence_length}, " - f"fft_length={fft_length}" - ) - if isinstance(window, str): - if window not in {"hann", "hamming"}: - raise ValueError( - "If a string is passed to `window`, it must be one of " - f'`"hann"`, `"hamming"`. Received: window={window}' - ) - x = convert_to_tensor(x) - ori_dtype = x.dtype - - if center: - pad_width = [(0, 0) for _ in range(len(x.shape))] - pad_width[-1] = (fft_length // 2, fft_length // 2) - x = np.pad(x, pad_width, mode="reflect") - - l_pad = (fft_length - sequence_length) // 2 - r_pad = fft_length - sequence_length - l_pad - - if window is not None: - if isinstance(window, str): - win = convert_to_tensor( - scipy.signal.get_window(window, sequence_length), dtype=x.dtype - ) - else: - win = convert_to_tensor(window, dtype=x.dtype) - if len(win.shape) != 1 or win.shape[-1] != sequence_length: - raise ValueError( - "The shape of `window` must be equal to [sequence_length]." 
- f"Received: window shape={win.shape}" - ) - win = np.pad(win, [[l_pad, r_pad]]) - else: - win = np.ones((sequence_length + l_pad + r_pad), dtype=x.dtype) - - x = scipy.signal.stft( - x, - fs=1.0, - window=win, - nperseg=(sequence_length + l_pad + r_pad), - noverlap=(sequence_length + l_pad + r_pad - sequence_stride), - nfft=fft_length, - boundary=None, - padded=False, - )[-1] - - # scale and swap to (..., num_sequences, fft_bins) - x = x / np.sqrt(1.0 / win.sum() ** 2) - x = np.swapaxes(x, -2, -1) - return np.real(x).astype(ori_dtype), np.imag(x).astype(ori_dtype) + raise NotImplementedError( + "`stft` is not supported with openvino backend" + ) def istft( - x, - sequence_length, - sequence_stride, - fft_length, - length=None, - window="hann", - center=True, -): - x = _get_complex_tensor_from_tuple(x) - dtype = np.real(x).dtype - - expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) - l_pad = (fft_length - sequence_length) // 2 - r_pad = fft_length - sequence_length - l_pad - - if window is not None: - if isinstance(window, str): - win = convert_to_tensor( - scipy.signal.get_window(window, sequence_length), dtype=dtype - ) - else: - win = convert_to_tensor(window, dtype=dtype) - if len(win.shape) != 1 or win.shape[-1] != sequence_length: - raise ValueError( - "The shape of `window` must be equal to [sequence_length]." - f"Received: window shape={win.shape}" - ) - win = np.pad(win, [[l_pad, r_pad]]) - else: - win = np.ones((sequence_length + l_pad + r_pad), dtype=dtype) - - x = scipy.signal.istft( x, - fs=1.0, - window=win, - nperseg=(sequence_length + l_pad + r_pad), - noverlap=(sequence_length + l_pad + r_pad - sequence_stride), - nfft=fft_length, - boundary=False, - time_axis=-2, - freq_axis=-1, - )[-1] - - # scale - x = x / win.sum() if window is not None else x / sequence_stride - - start = 0 if center is False else fft_length // 2 - if length is not None: - end = start + length - elif center is True: - end = -(fft_length // 2) - else: - end = expected_output_len - return x[..., start:end] + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, +): + raise NotImplementedError( + "`istft` is not supported with openvino backend" + ) def rsqrt(x): - return 1.0 / np.sqrt(x) + raise NotImplementedError( + "`rsqrt` is not supported with openvino backend" + ) def erf(x): - return np.array(scipy.special.erf(x)) + raise NotImplementedError( + "`erf` is not supported with openvino backend" + ) def erfinv(x): - return np.array(scipy.special.erfinv(x)) + raise NotImplementedError( + "`erfinv` is not supported with openvino backend" + ) def solve(a, b): - a = convert_to_tensor(a) - b = convert_to_tensor(b) - return np.linalg.solve(a, b) + raise NotImplementedError( + "`solve` is not supported with openvino backend" + ) def norm(x, ord=None, axis=None, keepdims=False): - x = convert_to_tensor(x) - dtype = standardize_dtype(x.dtype) - if "int" in dtype or dtype == "bool": - dtype = dtypes.result_type(x.dtype, "float32") - return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims).astype( - dtype + raise NotImplementedError( + "`norm` is not supported with openvino backend" ) diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py index 0270bc0492e..26f9b59de2a 100644 --- a/keras/src/backend/openvino/nn.py +++ b/keras/src/backend/openvino/nn.py @@ -1,14 +1,4 @@ -import jax import numpy as np -from jax import lax - -from keras.src import backend -from keras.src.backend.common.backend_utils import ( - 
compute_conv_transpose_padding_args_for_jax, -) -from keras.src.backend.openvino.core import cast -from keras.src.backend.openvino.core import convert_to_tensor -from keras.src.backend.openvino.core import is_tensor def relu(x): @@ -48,13 +38,15 @@ def silu(x): def log_sigmoid(x): - from openvino.runtime.opset14 import softplus, negative - return negative(softplus(negative(x))) + raise NotImplementedError( + "`log_sigmoid` is not supported with openvino backend" + ) def leaky_relu(x, negative_slope=0.2): - x = convert_to_tensor(x) - return np.maximum(x, np.array(negative_slope, x.dtype) * x) + raise NotImplementedError( + "`leaky_relu` is not supported with openvino backend" + ) def hard_sigmoid(x): @@ -101,63 +93,6 @@ def log_softmax(x, axis=None): return log_softmax(x, axis) -def _convert_to_spatial_operand( - x, - num_spatial_dims, - data_format="channels_last", - include_batch_and_channels=True, -): - # Helper function that converts an operand to a spatial operand. - x = (x,) * num_spatial_dims if isinstance(x, int) else x - if not include_batch_and_channels: - return x - if data_format == "channels_last": - x = (1,) + x + (1,) - else: - x = (1,) + (1,) + x - return x - - -def _pool( - inputs, - initial_value, - reduce_fn, - pool_size, - strides=None, - padding="valid", -): - """Helper function to define pooling functions. - - Args: - inputs: input data of shape `N+2`. - initial_value: the initial value for the reduction. - reduce_fn: a reduce function of the form `(T, T) -> T`. - pool_size: a sequence of `N` integers, representing the window size to - reduce over. - strides: a sequence of `N` integers, representing the inter-window - strides (default: `(1, ..., 1)`). - padding: either the string `same` or `valid`. - - Returns: - The output of the reduction for each window slice. - """ - if padding not in ("same", "valid"): - raise ValueError( - f"Invalid padding '{padding}', must be 'same' or 'valid'." - ) - padding = padding.upper() - return np.array( - lax.reduce_window( - inputs, - initial_value, - reduce_fn, - pool_size, - strides, - padding, - ) - ) - - def max_pool( inputs, pool_size, @@ -165,16 +100,9 @@ def max_pool( padding="valid", data_format=None, ): - data_format = backend.standardize_data_format(data_format) - num_spatial_dims = inputs.ndim - 2 - pool_size = _convert_to_spatial_operand( - pool_size, num_spatial_dims, data_format - ) - strides = pool_size if strides is None else strides - strides = _convert_to_spatial_operand( - strides, num_spatial_dims, data_format + raise NotImplementedError( + "`max_pool` is not supported with openvino backend" ) - return _pool(inputs, -np.inf, lax.max, pool_size, strides, padding) def average_pool( @@ -184,61 +112,8 @@ def average_pool( padding, data_format=None, ): - data_format = backend.standardize_data_format(data_format) - num_spatial_dims = inputs.ndim - 2 - pool_size = _convert_to_spatial_operand( - pool_size, num_spatial_dims, data_format - ) - strides = pool_size if strides is None else strides - strides = _convert_to_spatial_operand( - strides, num_spatial_dims, data_format - ) - - pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding) - if padding == "valid": - # Avoid the extra reduce_window. - return pooled / np.prod(pool_size) - else: - # Count the number of valid entries at each input point, then use that - # for computing average. Assumes that any two arrays of same shape will - # be padded the same. Avoid broadcasting on axis where pooling is - # skipped. 
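# [Hedged sketch, not part of this patch] If pooling support lands later,
# `max_pool` could lower to OpenVINO's MaxPool. The helper below is an
# illustration under two assumptions: inputs are already channels-first
# (channels_last would need transposes around the op), and the opset14
# Python signature is as shown here.
def _max_pool_nchw_sketch(x, pool_size, strides=None, padding="valid"):
    from openvino.runtime import opset14

    strides = strides or pool_size
    node = opset14.max_pool(
        x,
        strides=list(strides),
        dilations=[1] * len(pool_size),
        pads_begin=[0] * len(pool_size),
        pads_end=[0] * len(pool_size),
        kernel_shape=list(pool_size),
        rounding_type="floor",
        auto_pad="same_upper" if padding == "same" else "valid",
    )
    return node.output(0)  # pooled values; MaxPool-8 puts indices on output 1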
- shape = [ - (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size) - ] - window_counts = _pool( - np.ones(shape, inputs.dtype), - 0.0, - lax.add, - pool_size, - strides, - padding, - ) - return pooled / window_counts - - -def _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format="channels_last", - transpose=False, -): - """Create a `lax.ConvDimensionNumbers` for the given inputs.""" - num_dims = num_spatial_dims + 2 - - if data_format == "channels_last": - spatial_dims = tuple(range(1, num_dims - 1)) - inputs_dn = (0, num_dims - 1) + spatial_dims - else: - spatial_dims = tuple(range(2, num_dims)) - inputs_dn = (0, 1) + spatial_dims - - if transpose: - kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2)) - else: - kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2)) - - return lax.ConvDimensionNumbers( - lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn + raise NotImplementedError( + "`average_pool` is not supported with openvino backend" ) @@ -250,47 +125,8 @@ def conv( data_format=None, dilation_rate=1, ): - data_format = backend.standardize_data_format(data_format) - num_spatial_dims = inputs.ndim - 2 - dimension_numbers = _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format, - transpose=False, - ) - strides = _convert_to_spatial_operand( - strides, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - dilation_rate = _convert_to_spatial_operand( - dilation_rate, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - if data_format == "channels_last": - channels = inputs.shape[-1] - else: - channels = inputs.shape[1] - kernel_in_channels = kernel.shape[-2] - if channels % kernel_in_channels > 0: - raise ValueError( - "The number of input channels must be evenly divisible by " - f"kernel's in_channels. Received input channels {channels} and " - f"kernel in_channels {kernel_in_channels}. 
" - ) - feature_group_count = channels // kernel_in_channels - return np.array( - jax.lax.conv_general_dilated( - inputs, - kernel if is_tensor(kernel) else kernel.numpy(), - strides, - padding, - rhs_dilation=dilation_rate, - dimension_numbers=dimension_numbers, - feature_group_count=feature_group_count, - ) + raise NotImplementedError( + "`conv` is not supported with openvino backend" ) @@ -302,42 +138,8 @@ def depthwise_conv( data_format=None, dilation_rate=1, ): - data_format = backend.standardize_data_format(data_format) - num_spatial_dims = inputs.ndim - 2 - dimension_numbers = _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format, - transpose=False, - ) - strides = _convert_to_spatial_operand( - strides, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - dilation_rate = _convert_to_spatial_operand( - dilation_rate, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - feature_group_count = ( - inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1] - ) - kernel = np.reshape( - kernel if is_tensor(kernel) else kernel.numpy(), - kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]), - ) - return np.array( - jax.lax.conv_general_dilated( - inputs, - kernel, - strides, - padding, - rhs_dilation=dilation_rate, - dimension_numbers=dimension_numbers, - feature_group_count=feature_group_count, - ) + raise NotImplementedError( + "`depthwise_conv` is not supported with openvino backend" ) @@ -350,22 +152,8 @@ def separable_conv( data_format=None, dilation_rate=1, ): - data_format = backend.standardize_data_format(data_format) - depthwise_conv_output = depthwise_conv( - inputs, - depthwise_kernel, - strides, - padding, - data_format, - dilation_rate, - ) - return conv( - depthwise_conv_output, - pointwise_kernel, - strides=1, - padding="valid", - data_format=data_format, - dilation_rate=dilation_rate, + raise NotImplementedError( + "`separable_conv` is not supported with openvino backend" ) @@ -378,538 +166,59 @@ def conv_transpose( data_format=None, dilation_rate=1, ): - data_format = backend.standardize_data_format(data_format) - num_spatial_dims = inputs.ndim - 2 - padding_values = compute_conv_transpose_padding_args_for_jax( - input_shape=inputs.shape, - kernel_shape=kernel.shape, - strides=strides, - padding=padding, - output_padding=output_padding, - dilation_rate=dilation_rate, - ) - dimension_numbers = _convert_to_lax_conv_dimension_numbers( - num_spatial_dims, - data_format, - transpose=False, - ) - strides = _convert_to_spatial_operand( - strides, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - dilation_rate = _convert_to_spatial_operand( - dilation_rate, - num_spatial_dims, - data_format, - include_batch_and_channels=False, - ) - - return np.array( - jax.lax.conv_transpose( - inputs, - kernel if is_tensor(kernel) else kernel.numpy(), - strides, - padding=padding_values, - rhs_dilation=dilation_rate, - dimension_numbers=dimension_numbers, - transpose_kernel=True, - ) + raise NotImplementedError( + "`conv_transpose` is not supported with openvino backend" ) def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False): - if sparse: - raise ValueError("Unsupported value `sparse=True` with numpy backend") - x = convert_to_tensor(x) - input_shape = x.shape - - # Shrink the last dimension if the shape is (..., 1). 
- if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: - input_shape = tuple(input_shape[:-1]) - - x = x.reshape(-1) - if not num_classes: - num_classes = np.max(x) + 1 - - batch_size = x.shape[0] - categorical = np.zeros((batch_size, num_classes), dtype=dtype) - valid_indices = x >= 0 - categorical[np.arange(batch_size)[valid_indices], x[valid_indices]] = 1 - - # First, reshape the array with the extra dimension at the end - output_shape = input_shape + (num_classes,) - categorical = np.reshape(categorical, output_shape) - - # Then, move this new dimension to the right place (according to axis) - if axis != -1: - categorical = np.moveaxis(categorical, -1, axis) - - return categorical + raise NotImplementedError( + "`one_hot` is not supported with openvino backend" + ) def multi_hot(x, num_classes, axis=-1, dtype="float32", sparse=False): - if sparse: - raise ValueError("Unsupported value `sparse=True` with numpy backend") - x = convert_to_tensor(x) - reduction_axis = 1 if len(x.shape) > 1 else 0 - outputs = np.max( - one_hot(cast(x, "int32"), num_classes, axis=axis, dtype=dtype), - axis=reduction_axis, + raise NotImplementedError( + "`multi_hot` is not supported with openvino backend" ) - return outputs def categorical_crossentropy(target, output, from_logits=False, axis=-1): - target = np.array(target) - output = np.array(output) - - if target.shape != output.shape: - raise ValueError( - "Arguments `target` and `output` must have the same shape. " - "Received: " - f"target.shape={target.shape}, output.shape={output.shape}" - ) - if len(target.shape) < 1: - raise ValueError( - "Arguments `target` and `output` must be at least rank 1. " - "Received: " - f"target.shape={target.shape}, output.shape={output.shape}" - ) - - if from_logits: - log_prob = log_softmax(output, axis=axis) - else: - output = output / np.sum(output, axis, keepdims=True) - output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) - log_prob = np.log(output) - return -np.sum(target * log_prob, axis=axis) + raise NotImplementedError( + "`categorical_crossentropy` is not supported with openvino backend" + ) def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): - target = np.array(target, dtype="int32") - output = np.array(output) - if len(target.shape) == len(output.shape) and target.shape[-1] == 1: - target = np.squeeze(target, axis=-1) - - if len(output.shape) < 1: - raise ValueError( - "Argument `output` must be at least rank 1. " - "Received: " - f"output.shape={output.shape}" - ) - if target.shape != output.shape[:-1]: - raise ValueError( - "Arguments `target` and `output` must have the same shape " - "up until the last dimension: " - f"target.shape={target.shape}, output.shape={output.shape}" - ) - if from_logits: - log_prob = log_softmax(output, axis=axis) - else: - output = output / np.sum(output, axis, keepdims=True) - output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) - log_prob = np.log(output) - target = one_hot(target, output.shape[axis], axis=axis) - return -np.sum(target * log_prob, axis=axis) + raise NotImplementedError( + "`sparse_categorical_crossentropy` is not supported with openvino backend" + ) def binary_crossentropy(target, output, from_logits=False): - target = np.array(target) - output = np.array(output) - - if target.shape != output.shape: - raise ValueError( - "Arguments `target` and `output` must have the same shape. 
" - "Received: " - f"target.shape={target.shape}, output.shape={output.shape}" - ) - - if from_logits: - output = sigmoid(output) - - output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) - bce = target * np.log(output) - bce += (1.0 - target) * np.log(1.0 - output) - return -bce + raise NotImplementedError( + "`binary_crossentropy` is not supported with openvino backend" + ) def moments(x, axes, keepdims=False, synchronized=False): - if synchronized: - raise NotImplementedError( - "Argument synchronized=True is not supported with NumPy." - ) - axes = tuple(axes) if isinstance(axes, list) else axes - # The dynamic range of float16 is too limited for statistics. As a - # workaround, we simply perform the operations on float32 and convert back - # to float16 - need_cast = False - ori_dtype = backend.standardize_dtype(x.dtype) - if ori_dtype == "float16": - need_cast = True - x = cast(x, "float32") - - mean = np.mean(x, axes, keepdims=True) - - # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, It is faster - # but less numerically stable. - variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean) - - if not keepdims: - mean = np.squeeze(mean, axes) - variance = np.squeeze(variance, axes) - if need_cast: - # avoid overflow and underflow when casting from float16 to float32 - mean = np.clip(mean, np.finfo(np.float16).min, np.finfo(np.float16).max) - variance = np.clip( - variance, np.finfo(np.float16).min, np.finfo(np.float16).max - ) - mean = cast(mean, ori_dtype) - variance = cast(variance, ori_dtype) - return mean, variance + raise NotImplementedError( + "`moments` is not supported with openvino backend" + ) def batch_normalization( x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 ): - shape = [1] * len(x.shape) - shape[axis] = mean.shape[0] - mean = np.reshape(mean, shape) - variance = np.reshape(variance, shape) - - inv = 1.0 / np.sqrt(variance + epsilon) - if scale is not None: - scale = np.reshape(scale, shape) - inv = inv * scale - - res = -mean * inv - if offset is not None: - offset = np.reshape(offset, shape) - res = res + offset - - return x * inv + res - - -def ctc_loss(target, output, target_length, output_length, mask_index=0): - # Ref: https://github.com/google-deepmind/optax - # optax.ctc_loss_with_forward_probs - target = convert_to_tensor(target, dtype="int32") - output = convert_to_tensor(output) - target_length = convert_to_tensor(target_length, "int32") - output_length = convert_to_tensor(output_length, "int32") - batch_size, max_input_length, num_classes = output.shape - batch_size, max_label_length = target.shape - log_epsilon = -1e5 - - # Ensure that the dtype promotion behavior matchs that of `tf.nn.ctc_loss` - dtype = backend.result_type(output.dtype, "float32") - output = output.astype(dtype) - - def _lengths_to_paddings(lengths, max_length): - indices = np.arange(max_length).reshape( - (1,) * lengths.ndim + (max_length,) - ) - lengths = np.expand_dims(lengths, axis=-1) - elem_valid = indices < lengths - return np.logical_not(elem_valid) - - target_paddings = _lengths_to_paddings(target_length, max_label_length) - output_paddings = _lengths_to_paddings(output_length, max_input_length) - target_paddings = target_paddings.astype(output.dtype) - output_paddings = output_paddings.astype(output.dtype) - - logprobs = log_softmax(output, axis=-1) - label_lengths = max_label_length - np.sum(target_paddings, axis=1).astype( - np.int32 - ) - - # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. 
- repeat = (target[:, :-1] == target[:, 1:]).astype(np.float32) - repeat = np.pad(repeat, ((0, 0), (0, 1))) - - logprobs_phi = logprobs[:, :, mask_index: mask_index + 1] # [B, T, 1] - logprobs_phi = np.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] - - _one_hot = one_hot(target, num_classes=num_classes) # [B, N, K] - logprobs_emit = np.einsum("btk,bnk->btn", logprobs, _one_hot) - logprobs_emit = np.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] - - # [B, N] - logalpha_phi_init = ( - np.ones((batch_size, max_label_length + 1), dtype=output.dtype) - * log_epsilon - ) - logalpha_phi_init[:, 0] = 0.0 - logalpha_emit_init = ( - np.ones((batch_size, max_label_length), dtype=output.dtype) - * log_epsilon - ) - - def update_phi_score(phi, added_score): - # Update `phi[:, 1:]`` with adding `added_score` in log space. - return np.concatenate( - [phi[:, :1], np.logaddexp(phi[:, 1:], added_score)], axis=-1 - ) - - def loop_body(prev, x): - prev_phi, prev_emit = prev - # emit-to-phi epsilon transition, except if the next label is repetition - prev_phi_orig = prev_phi - prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) - - logprob_emit, logprob_phi, pad = x - - # phi-to-emit transition - next_emit = np.logaddexp( - prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit - ) - # self-loop transition - next_phi = prev_phi + logprob_phi - # emit-to-phi blank transition only when the next label is repetition - next_phi = update_phi_score( - next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat) - ) - - pad = pad.reshape((batch_size, 1)) - next_emit = pad * prev_emit + (1.0 - pad) * next_emit - next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi - - return (next_phi, next_emit), (next_phi, next_emit) - - def np_scan(f, init, xs): - carry = init - ys = [] - for x in zip(*xs): - carry, y = f(carry, x) - ys.append(y) - result = [] - for i in range(len(ys[0])): - result.append(np.stack([y[i] for y in ys])) - return carry, result - - xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0))) - _, (logalpha_phi, logalpha_emit) = np_scan( - loop_body, (logalpha_phi_init, logalpha_emit_init), xs - ) - - # last row needs to be updated with the last epsilon transition - logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1]) - logalpha_phi[-1] = logalpha_phi_last - - # extract per_seq_loss - # [B, N+1] - _one_hot = one_hot(label_lengths, num_classes=max_label_length + 1) - per_seq_loss = -np.einsum("bn,bn->b", logalpha_phi_last, _one_hot) - return per_seq_loss - - -def _ctc_greedy_decode( - inputs, - sequence_lengths, - merge_repeated=True, - mask_index=None, -): - inputs = convert_to_tensor(inputs) - sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") - batch_size, max_length, num_classes = inputs.shape - - if mask_index is None: - mask_index = num_classes - 1 - - indices = np.argmax(inputs, axis=-1).astype("int32") - scores = np.max(inputs, axis=-1) - - seqlen_mask = np.arange(max_length)[None, :] - seqlen_mask = seqlen_mask >= sequence_lengths[:, None] - - indices = np.where(seqlen_mask, mask_index, indices) - scores = np.where(seqlen_mask, 0.0, scores) - - if merge_repeated: - repeat_mask = indices[:, 1:] == indices[:, :-1] - repeat_mask = np.pad(repeat_mask, ((0, 0), (1, 0))) - indices = np.where(repeat_mask, mask_index, indices) - - # We set to -1 for blank labels - invalid_mask = indices == mask_index - indices = np.where(invalid_mask, -1, indices) - - # We rearrange the indices by moving `mask_index` to the end of the array - 
order = np.expand_dims(np.arange(max_length), axis=0) # [1, N] - order = np.tile(order, (batch_size, 1)) # [B, N] - order = np.where(invalid_mask, max_length, order) - order = np.argsort(order, axis=-1) - indices = np.take_along_axis(indices, order, axis=-1) - - scores = -np.sum(scores, axis=1)[:, None] - indices = np.expand_dims(indices, axis=0) - return indices, scores - - -def _ctc_beam_search_decode( - inputs, - sequence_lengths, - beam_width=100, - top_paths=1, - mask_index=None, -): - inputs = convert_to_tensor(inputs) - sequence_lengths = convert_to_tensor(sequence_lengths) - - batch_size, max_seq_len, num_classes = inputs.shape - inputs = log_softmax(inputs, axis=-1) - seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_lengths[:, None] - - if mask_index is None: - mask_index = num_classes - 1 - - # This is a workaround for the fact that np.argsort does not support - # the order parameter which is used to break ties when scores are equal. - # For compatibility with the tensorflow implementation, we flip the inputs - # and the mask_index, and then flip the classes back to the correct indices - inputs = np.flip(inputs, axis=2) - mask_index = num_classes - mask_index - 1 - - _pad = -1 - - init_paths = np.full( - (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=np.int32 + raise NotImplementedError( + "`batch_normalization` is not supported with openvino backend" ) - num_init_paths = np.min(np.array([num_classes, beam_width])) - max_classes = np.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:] - init_classes = np.where(max_classes == mask_index, _pad, max_classes) - init_paths[:, :num_init_paths, 0] = init_classes - init_scores = np.full( - (batch_size, 2 * beam_width), -np.inf, dtype=inputs.dtype - ) - init_scores[:, :num_init_paths] = np.take_along_axis( - inputs[:, 0], max_classes, axis=1 +def ctc_loss(target, output, target_length, output_length, mask_index=0): + raise NotImplementedError( + "`ctc_loss` is not supported with openvino backend" ) - init_masked = init_paths[:, :, 0] == _pad - - def _extend_paths(paths, scores, masked, x): - paths = np.repeat(paths, num_classes, axis=0) - scores = np.repeat(scores, num_classes) - masked = np.repeat(masked, num_classes) - - path_tail_index = np.argmax(paths == _pad, axis=1) - paths_arange = np.arange(2 * beam_width * num_classes) - path_tails = paths[paths_arange, path_tail_index - 1] - path_tails = np.where(path_tail_index == 0, _pad, path_tails) - - classes = np.arange(num_classes) - classes[mask_index] = _pad - classes = np.tile(classes, 2 * beam_width) - - prev_masked = masked - masked = classes == _pad - - masked_repeat = ~prev_masked & (path_tails == classes) - classes = np.where(masked_repeat, _pad, classes) - paths[paths_arange, path_tail_index] = classes - - x = np.tile(x, 2 * beam_width) - scores = scores + x - - return paths, scores, masked - - def _merge_scores(unique_inverse, scores): - scores_max = np.max(scores) - scores_exp = np.exp(scores - scores_max) - scores = np.zeros_like(scores) - for i, u in enumerate(unique_inverse): - scores[u] += scores_exp[i] - scores = np.log(scores) + scores_max - return scores - - def _prune_paths(paths, scores, masked): - paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0) - pad_size = (2 * num_classes * beam_width) - len(paths) - if pad_size > 0: - paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad) - paths = paths[: 2 * num_classes * beam_width] - if len(unique_inverse.shape) >= 2: - unique_inverse = np.squeeze(unique_inverse, axis=1) - - 
emit_scores = np.where(masked, -np.inf, scores) - mask_scores = np.where(masked, scores, -np.inf) - - emit_scores = _merge_scores(unique_inverse, emit_scores) - mask_scores = _merge_scores(unique_inverse, mask_scores) - - total_scores = np.logaddexp(emit_scores, mask_scores) - top_indices = np.argsort(total_scores, kind="stable")[-beam_width:] - - paths = paths[top_indices] - emit_scores = emit_scores[top_indices] - mask_scores = mask_scores[top_indices] - - paths = np.tile(paths, (2, 1)) - scores = np.concatenate([emit_scores, mask_scores]) - masked = np.concatenate( - [np.zeros(beam_width, bool), np.ones(beam_width, bool)] - ) - - return paths, scores, masked - - def _decode_step(paths, scores, masked, x): - paths, scores, masked = _extend_paths(paths, scores, masked, x) - paths, scores, masked = _prune_paths(paths, scores, masked) - return paths, scores, masked - - def _step(prev, x): - paths, scores, masked = prev - x, seqlen_mask = x - if not seqlen_mask: - paths, scores, masked = _decode_step(paths, scores, masked, x) - return (paths, scores, masked), None - - def _decode_batch( - init_paths, init_scores, init_masked, inputs, seqlen_mask - ): - def np_scan_only_carry(f, init, xs): - carry = init - for x in zip(*xs): - carry, y = f(carry, x) - return carry, None - - (paths, scores, masked), _ = np_scan_only_carry( - _step, - (init_paths, init_scores, init_masked), - (inputs[1:], seqlen_mask[1:]), - ) - - paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0) - pad_size = (2 * num_classes * beam_width) - len(paths) - if pad_size > 0: - paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad) - paths = paths[: 2 * num_classes * beam_width] - if len(unique_inverse.shape) >= 2: - unique_inverse = np.squeeze(unique_inverse, axis=1) - scores = _merge_scores(unique_inverse, scores) - - top_indices = np.argsort(scores)[-top_paths:][::-1] - paths = paths[top_indices] - scores = scores[top_indices] - - return paths, scores - - results = [ - _decode_batch(p, s, m, i, sm) - for p, s, m, i, sm in zip( - init_paths, init_scores, init_masked, inputs, seqlen_mask - ) - ] - paths = np.stack([r[0] for r in results]) - scores = np.stack([r[1] for r in results]) - - # convert classes back to the correct indices - paths = np.where(paths == _pad, _pad, num_classes - paths - 1) - paths = np.transpose(paths, [1, 0, 2]) - return paths, scores def ctc_decode( @@ -921,40 +230,12 @@ def ctc_decode( merge_repeated=True, mask_index=0, ): - inputs = convert_to_tensor(inputs) - dtype = backend.result_type(inputs.dtype, "float32") - inputs = cast(inputs, dtype) - - if strategy == "greedy": - return _ctc_greedy_decode( - inputs, - sequence_lengths, - merge_repeated=merge_repeated, - mask_index=mask_index, - ) - elif strategy == "beam_search": - return _ctc_beam_search_decode( - inputs, - sequence_lengths, - beam_width=beam_width, - top_paths=top_paths, - mask_index=mask_index, - ) - else: - raise ValueError( - f"Invalid strategy {strategy}. Supported values are " - "'greedy' and 'beam_search'." - ) + raise NotImplementedError( + "`ctc_decode` is not supported with openvino backend" + ) def psnr(x1, x2, max_val): - if x1.shape != x2.shape: - raise ValueError( - f"Input shapes {x1.shape} and {x2.shape} must " - "match for PSNR calculation. 
" - ) - - max_val = convert_to_tensor(max_val, dtype=x2.dtype) - mse = np.mean(np.square(x1 - x2)) - psnr = 20 * np.log10(max_val) - 10 * np.log10(mse) - return psnr + raise NotImplementedError( + "`psnr` is not supported with openvino backend" + ) diff --git a/keras/src/backend/openvino/random.py b/keras/src/backend/openvino/random.py index 028544e53b7..18c70757578 100644 --- a/keras/src/backend/openvino/random.py +++ b/keras/src/backend/openvino/random.py @@ -1,120 +1,58 @@ -import numpy as np - -from keras.src.backend.config import floatx -from keras.src.backend.openvino.nn import softmax -from keras.src.random.seed_generator import SeedGenerator -from keras.src.random.seed_generator import draw_seed -from keras.src.random.seed_generator import make_default_seed - - def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): - dtype = dtype or floatx() - seed = draw_seed(seed) - rng = np.random.default_rng(seed) - return rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype) + raise NotImplementedError( + "`normal` is not supported with openvino backend" + ) def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): - dtype = dtype or floatx() - seed = draw_seed(seed) - rng = np.random.default_rng(seed) - return rng.uniform(size=shape, low=minval, high=maxval).astype(dtype) + raise NotImplementedError( + "`uniform` is not supported with openvino backend" + ) def categorical(logits, num_samples, dtype="int64", seed=None): - seed = draw_seed(seed) - rng = np.random.default_rng(seed) - output = [] - for logits_instance in logits: - probabilities = softmax(logits_instance) - classes = np.arange(logits_instance.shape[-1]) - samples = rng.choice(classes, size=num_samples, p=probabilities) - output.append(samples) - return np.array(output).astype(dtype) + raise NotImplementedError( + "`categorical` is not supported with openvino backend" + ) def randint(shape, minval, maxval, dtype="int32", seed=None): - seed = draw_seed(seed) - rng = np.random.default_rng(seed) - output = rng.integers(low=minval, high=maxval, size=shape, dtype=dtype) - return output + raise NotImplementedError( + "`randint` is not supported with openvino backend" + ) def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): - dtype = dtype or floatx() - seed = draw_seed(seed) - rng = np.random.default_rng(seed) - - lower_bound = mean - 2 * stddev - upper_bound = mean + 2 * stddev - - flat_shape = np.prod(shape) - random_numbers = np.empty(0) - - # loop until we have enough valid numbers to fill our desired shape - while random_numbers.shape[0] < flat_shape: - # Generate a batch of random numbers from a normal distribution - batch = rng.normal(loc=mean, scale=stddev, size=flat_shape) - - # Filter the numbers to keep only those within the specified bounds - valid = batch[(batch >= lower_bound) & (batch <= upper_bound)] - - # Append the valid numbers to the result array - random_numbers = np.append(random_numbers, valid) - - # Truncate the result array to the desired size and reshape it - return random_numbers[:flat_shape].astype(dtype).reshape(shape) + raise NotImplementedError( + "`truncated_normal` is not supported with openvino backend" + ) def dropout(inputs, rate, noise_shape=None, seed=None): - dtype = inputs.dtype - seed = draw_seed(seed) - - keep_prob = 1.0 - rate - - # If noise_shape is not provided, use the shape of inputs - if noise_shape is None: - noise_shape = inputs.shape - else: - # If noise_shape is provided, replace None with corresponding - # input shape - noise_shape = [ - n 
if n is not None else inputs.shape[i] - for i, n in enumerate(noise_shape) - ] - - rng = np.random.default_rng(seed) - mask = rng.uniform(size=noise_shape) < keep_prob - mask = np.broadcast_to(mask, inputs.shape) - return np.where( - mask, (inputs / keep_prob).astype(dtype), np.zeros_like(inputs) + raise NotImplementedError( + "`dropout` is not supported with openvino backend" ) def shuffle(x, axis=0, seed=None): - seed = draw_seed(seed) - rng = np.random.default_rng(seed) - return rng.permuted(x, axis=axis) + raise NotImplementedError( + "`shuffle` is not supported with openvino backend" + ) def gamma(shape, alpha, dtype=None, seed=None): - dtype = dtype or floatx() - seed = draw_seed(seed) - rng = np.random.default_rng(seed) - return rng.gamma(alpha, scale=1.0, size=shape).astype(dtype) + raise NotImplementedError( + "`gamma` is not supported with openvino backend" + ) def binomial(shape, counts, probabilities, dtype=None, seed=None): - dtype = dtype or floatx() - seed = draw_seed(seed) - rng = np.random.default_rng(seed) - sample = rng.binomial(n=counts, p=probabilities, size=shape).astype(dtype) - return sample + raise NotImplementedError( + "`binomial` is not supported with openvino backend" + ) def beta(shape, alpha, beta, dtype=None, seed=None): - dtype = dtype or floatx() - seed = draw_seed(seed) - rng = np.random.default_rng(seed) - sample = rng.beta(a=alpha, b=beta, size=shape).astype(dtype) - return sample + raise NotImplementedError( + "`beta` is not supported with openvino backend" + ) diff --git a/keras/src/backend/openvino/rnn.py b/keras/src/backend/openvino/rnn.py index 07f65752514..41fba1ffc08 100644 --- a/keras/src/backend/openvino/rnn.py +++ b/keras/src/backend/openvino/rnn.py @@ -1,238 +1,43 @@ -import numpy as np - -from keras.src import tree - - def rnn( - step_function, - inputs, - initial_states, - go_backwards=False, - mask=None, - constants=None, - unroll=False, - input_length=None, - time_major=False, - zero_output_for_mask=False, - return_all_outputs=True, + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, ): - def swap_batch_timestep(input_t): - # Swap the batch and timestep dim for the incoming tensor. 
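# [Hedged illustration] With `rnn` stubbed out, every recurrent path fails
# fast on this backend. Calling the stub directly, assuming
# KERAS_BACKEND=openvino is set:
#
#   from keras.src.backend.openvino.rnn import rnn
#
#   rnn(step_function=None, inputs=None, initial_states=None)
#   # NotImplementedError: `rnn` is not supported with openvino backend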
- axes = list(range(len(input_t.shape))) - axes[0], axes[1] = 1, 0 - return np.transpose(input_t, axes) - - if not time_major: - inputs = tree.map_structure(swap_batch_timestep, inputs) - - flattened_inputs = tree.flatten(inputs) - time_steps = flattened_inputs[0].shape[0] - - if mask is not None: - if mask.dtype != "bool": - mask = mask.astype("bool") - if len(mask.shape) == 2: - mask = np.expand_dims(mask, axis=-1) - if not time_major: - mask = swap_batch_timestep(mask) - - if constants is None: - constants = [] - - def _expand_mask(mask_t, input_t, fixed_dim=1): - if tree.is_nested(mask_t): - raise ValueError( - f"mask_t is expected to be tensor, but got {mask_t}" - ) - if tree.is_nested(input_t): - raise ValueError( - f"input_t is expected to be tensor, but got {input_t}" - ) - rank_diff = len(input_t.shape) - len(mask_t.shape) - for _ in range(rank_diff): - mask_t = np.expand_dims(mask_t, -1) - multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:]) - return np.tile(mask_t, multiples) - - if unroll: - if not time_steps: - raise ValueError("Unrolling requires a fixed number of timesteps.") - states = tuple(initial_states) - successive_states = [] - successive_outputs = [] - - # Process the input tensors. The input tensor need to be split on the - # time_step dim, and reverse if go_backwards is True. In the case of - # nested input, the input is flattened and then transformed - # individually. The result of this will be a tuple of lists, each of - # the item in tuple is list of the tensor with shape (batch, feature) - def _process_single_input_t(input_t): - input_t = unstack(input_t) # unstack for time_step dim - if go_backwards: - input_t.reverse() - return input_t - - if tree.is_nested(inputs): - processed_input = tree.map_structure( - _process_single_input_t, inputs - ) - else: - processed_input = (_process_single_input_t(inputs),) - - def _get_input_tensor(time): - inp = [t_[time] for t_ in processed_input] - return tree.pack_sequence_as(inputs, inp) - - if mask is not None: - mask_list = unstack(mask) - if go_backwards: - mask_list.reverse() - - for i in range(time_steps): - inp = _get_input_tensor(i) - mask_t = mask_list[i] - output, new_states = step_function( - inp, tuple(states) + tuple(constants) - ) - tiled_mask_t = _expand_mask(mask_t, output) - - if not successive_outputs: - prev_output = np.zeros_like(output) - else: - prev_output = successive_outputs[-1] - - output = np.where(tiled_mask_t, output, prev_output) - - flat_states = tree.flatten(states) - flat_new_states = tree.flatten(new_states) - tiled_mask_t = tuple( - _expand_mask(mask_t, s) for s in flat_states - ) - flat_final_states = tuple( - np.where(m, s, ps) - for m, s, ps in zip( - tiled_mask_t, flat_new_states, flat_states - ) - ) - states = tree.pack_sequence_as(states, flat_final_states) - - if return_all_outputs: - successive_outputs.append(output) - successive_states.append(states) - else: - successive_outputs = [output] - successive_states = [states] - last_output = successive_outputs[-1] - new_states = successive_states[-1] - outputs = np.stack(successive_outputs) - - else: # mask is None - for i in range(time_steps): - inp = _get_input_tensor(i) - output, states = step_function( - inp, tuple(states) + tuple(constants) - ) - if return_all_outputs: - successive_outputs.append(output) - successive_states.append(states) - else: - successive_outputs = [output] - successive_states = [states] - last_output = successive_outputs[-1] - new_states = successive_states[-1] - outputs = np.stack(successive_outputs) 
- - else: # Unroll == False - if mask is not None: - - def _step(states, current_input): - current_input, current_mask = current_input - is_masked = np.all( - np.logical_not(current_mask), axis=-1, keepdims=True - ) - - output_t, new_states = step_function(current_input, states) - - if zero_output_for_mask: - masked_outs = np.where( - is_masked, np.zeros_like(output_t), output_t - ) - else: - # Assume the first state is the previous output. - output_tm1 = states[0] - masked_outs = np.where(is_masked, output_tm1, output_t) - - new_states = [ - np.where(is_masked, s, ns) - for s, ns in zip(states, new_states) - ] - return (new_states, masked_outs) - - scan_xs = (inputs, mask) - - else: - - def _step(states, current_input): - output_t, new_states = step_function(current_input, states) - return new_states, output_t - - scan_xs = inputs - - new_states, outputs = numpy_scan( - f=_step, - init=initial_states, - xs=scan_xs, - reverse=go_backwards, - mask=mask, - ) - - if go_backwards: - outputs = np.flip(outputs, axis=0) - last_output = outputs[-1] - - if not time_major: - outputs = tree.map_structure(swap_batch_timestep, outputs) - - return last_output, outputs, new_states + raise NotImplementedError( + "`rnn` is not supported with openvino backend" + ) def lstm(*args, **kwargs): - raise NotImplementedError + raise NotImplementedError( + "`lstm` is not supported with openvino backend" + ) def gru(*args, **kwargs): - raise NotImplementedError + raise NotImplementedError( + "`gru` is not supported with openvino backend" + ) def unstack(x, axis=0): - return [x.take(i, axis) for i in range(x.shape[axis])] + raise NotImplementedError( + "`unstack` is not supported with openvino backend" + ) def numpy_scan(f, init, xs, reverse=False, mask=None): - states = init - outputs = [] - - if mask is not None: - x, mask = xs - x = np.flip(x, axis=0) if reverse else x - mask = np.flip(mask, axis=0) if reverse else mask - - for each_x, each_mask in zip(x, mask): - states, output = f(states, (each_x, each_mask)) - outputs.append(output) - else: - xs = np.flip(xs, axis=0) if reverse else xs - - for x in xs: - states, output = f(states, x) - outputs.append(output) - - outputs = np.array(outputs) - - if reverse: - outputs = np.flip(outputs, axis=0) - - return states, outputs + raise NotImplementedError( + "`numpy_scan` is not supported with openvino backend" + ) def cudnn_ok(*args, **kwargs): From 9e29a73cafc9cf9231a61e1fa20f399ea740bdc9 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 24 Sep 2024 10:40:47 +0400 Subject: [PATCH 05/19] Fix sorting imports Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index f6d91af9456..175e945a60d 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1,5 +1,5 @@ from keras.src.backend.common import dtypes -from keras.src.backend.openvino.core import convert_to_tensor, ov_to_keras_type, OPENVINO_DTYPES +from keras.src.backend.openvino.core import OPENVINO_DTYPES, convert_to_tensor, ov_to_keras_type def _align_operand_types(x1, x2, op_name): From 716a0bb93315b4fd1cc331fdc7f271f3e1eaf028 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 24 Sep 2024 10:45:56 +0400 Subject: [PATCH 06/19] Format imports Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/numpy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 175e945a60d..eabbeb64206 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1,5 +1,7 @@ from keras.src.backend.common import dtypes -from keras.src.backend.openvino.core import OPENVINO_DTYPES, convert_to_tensor, ov_to_keras_type +from keras.src.backend.openvino.core import OPENVINO_DTYPES +from keras.src.backend.openvino.core import convert_to_tensor +from keras.src.backend.openvino.core import ov_to_keras_type def _align_operand_types(x1, x2, op_name): From c09712668f4daa03d5037504859ea505c4f3f0cd Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 24 Sep 2024 10:55:43 +0400 Subject: [PATCH 07/19] Fix sorting imports Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/core.py | 13 +++++++------ keras/src/backend/openvino/linalg.py | 4 ---- keras/src/backend/openvino/nn.py | 3 ++- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index ea4cec6001a..6d49aa56b96 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -1,15 +1,15 @@ import contextlib import numpy as np +import openvino as ov -from keras.src.backend.common import global_state from keras.src import tree from keras.src.backend.common import KerasVariable +from keras.src.backend.common import global_state from keras.src.backend.common import standardize_dtype from keras.src.backend.common.dtypes import result_type from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope -import openvino as ov SUPPORTS_SPARSE_TENSORS = False @@ -39,6 +39,7 @@ def ov_to_keras_type(ov_type): f"Requested OpenVINO type has no keras analogue '{ov_type.to_string()}'" ) + @contextlib.contextmanager def device_scope(device_name): current_device = _parse_device_input(device_name) @@ -226,10 +227,10 @@ def slice_update(inputs, start_indices, updates): def while_loop( - cond, - body, - loop_vars, - maximum_iterations=None, + cond, + body, + loop_vars, + maximum_iterations=None, ): raise NotImplementedError( "`while_loop` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/linalg.py b/keras/src/backend/openvino/linalg.py index 9675e9fa4be..5c6c65cf86e 100644 --- a/keras/src/backend/openvino/linalg.py +++ b/keras/src/backend/openvino/linalg.py @@ -1,7 +1,3 @@ -import numpy as np -from keras.src.backend.openvino.core import convert_to_tensor - - def cholesky(a): raise NotImplementedError( "`cholesky` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py index 26f9b59de2a..3362a9d4b9f 100644 --- a/keras/src/backend/openvino/nn.py +++ b/keras/src/backend/openvino/nn.py @@ -33,7 +33,8 @@ def softsign(x): def silu(x): - from openvino.runtime.opset14 import sigmoid, multiply + from openvino.runtime.opset14 import sigmoid + from openvino.runtime.opset14 import multiply return multiply(x, sigmoid(x)) From 65fe256da19f316d76f095ca908f1c2ab0d5e86b Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 24 Sep 2024 11:05:10 +0400 Subject: [PATCH 08/19] Fix sorting imports Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/nn.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py index 3362a9d4b9f..d9be1ede4a0 100644 
--- a/keras/src/backend/openvino/nn.py +++ b/keras/src/backend/openvino/nn.py @@ -1,8 +1,19 @@ import numpy as np +from openvino.runtime.opset14 import elu +from openvino.runtime.opset14 import gelu +from openvino.runtime.opset14 import hard_sigmoid +from openvino.runtime.opset14 import log_softmax +from openvino.runtime.opset14 import multiply +from openvino.runtime.opset14 import relu +from openvino.runtime.opset14 import selu +from openvino.runtime.opset14 import sigmoid +from openvino.runtime.opset14 import softmax +from openvino.runtime.opset14 import softplus +from openvino.runtime.opset14 import softsign +from openvino.runtime.opset14 import tanh def relu(x): - from openvino.runtime.opset14 import relu return relu(x) @@ -13,28 +24,22 @@ def relu6(x): def sigmoid(x): - from openvino.runtime.opset14 import sigmoid return sigmoid(x) def tanh(x): - from openvino.runtime.opset14 import tanh return tanh(x) def softplus(x): - from openvino.runtime.opset14 import softplus return softplus(x) def softsign(x): - from openvino.runtime.opset14 import softsign return softsign(x) def silu(x): - from openvino.runtime.opset14 import sigmoid - from openvino.runtime.opset14 import multiply return multiply(x, sigmoid(x)) @@ -51,19 +56,16 @@ def leaky_relu(x, negative_slope=0.2): def hard_sigmoid(x): - from openvino.runtime.opset14 import hard_sigmoid alpha = 1 / np.array(6.0, dtype=np.float32) beta = np.array(0.5, dtype=np.float32) return hard_sigmoid(x, alpha, beta) def hard_silu(x): - from openvino.runtime.opset14 import multiply return multiply(x, hard_sigmoid(x)) def elu(x, alpha=1.0): - from openvino.runtime.opset14 import elu return elu(x, alpha) @@ -72,12 +74,10 @@ def selu( alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946, ): - from openvino.runtime.opset14 import selu return selu(x, alpha, scale) def gelu(x, approximate=True): - from openvino.runtime.opset14 import gelu approximate_mode = "erf" if approximate: approximate_mode = "tanh" @@ -85,12 +85,10 @@ def gelu(x, approximate=True): def softmax(x, axis=None): - from openvino.runtime.opset14 import softmax return softmax(x, axis) def log_softmax(x, axis=None): - from openvino.runtime.opset14 import log_softmax return log_softmax(x, axis) From fe29eb9b0a071867e5ef8e95b5e2e6b0995eed0d Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 24 Sep 2024 14:21:45 +0400 Subject: [PATCH 09/19] Fix inference Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/nn.py | 39 ++++++++++------------------ keras/src/backend/openvino/random.py | 3 +++ keras/src/layers/layer.py | 2 +- 3 files changed, 18 insertions(+), 26 deletions(-) diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py index d9be1ede4a0..3b75d4476e4 100644 --- a/keras/src/backend/openvino/nn.py +++ b/keras/src/backend/openvino/nn.py @@ -1,20 +1,9 @@ import numpy as np -from openvino.runtime.opset14 import elu -from openvino.runtime.opset14 import gelu -from openvino.runtime.opset14 import hard_sigmoid -from openvino.runtime.opset14 import log_softmax -from openvino.runtime.opset14 import multiply -from openvino.runtime.opset14 import relu -from openvino.runtime.opset14 import selu -from openvino.runtime.opset14 import sigmoid -from openvino.runtime.opset14 import softmax -from openvino.runtime.opset14 import softplus -from openvino.runtime.opset14 import softsign -from openvino.runtime.opset14 import tanh +from openvino.runtime import opset14 def relu(x): - return relu(x) + return opset14.relu(x) def relu6(x): @@ 
-24,23 +13,23 @@ def relu6(x): def sigmoid(x): - return sigmoid(x) + return opset14.sigmoid(x) def tanh(x): - return tanh(x) + return opset14.tanh(x) def softplus(x): - return softplus(x) + return opset14.softplus(x) def softsign(x): - return softsign(x) + return opset14.softsign(x) def silu(x): - return multiply(x, sigmoid(x)) + return opset14.multiply(x, opset14.sigmoid(x)) def log_sigmoid(x): @@ -58,15 +47,15 @@ def leaky_relu(x, negative_slope=0.2): def hard_sigmoid(x): alpha = 1 / np.array(6.0, dtype=np.float32) beta = np.array(0.5, dtype=np.float32) - return hard_sigmoid(x, alpha, beta) + return opset14.hard_sigmoid(x, alpha, beta) def hard_silu(x): - return multiply(x, hard_sigmoid(x)) + return opset14.multiply(x, hard_sigmoid(x)) def elu(x, alpha=1.0): - return elu(x, alpha) + return opset14.elu(x, alpha) def selu( @@ -74,22 +63,22 @@ def selu( alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946, ): - return selu(x, alpha, scale) + return opset14.selu(x, alpha, scale) def gelu(x, approximate=True): approximate_mode = "erf" if approximate: approximate_mode = "tanh" - return gelu(x, approximate_mode) + return opset14.gelu(x, approximate_mode) def softmax(x, axis=None): - return softmax(x, axis) + return opset14.softmax(x, axis) def log_softmax(x, axis=None): - return log_softmax(x, axis) + return opset14.log_softmax(x, axis) def max_pool( diff --git a/keras/src/backend/openvino/random.py b/keras/src/backend/openvino/random.py index 18c70757578..ed902e34ad1 100644 --- a/keras/src/backend/openvino/random.py +++ b/keras/src/backend/openvino/random.py @@ -1,3 +1,6 @@ +from keras.src.random.seed_generator import SeedGenerator + + def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): raise NotImplementedError( "`normal` is not supported with openvino backend" diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 6d01f06ef73..3756b79e093 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -53,7 +53,7 @@ elif backend.backend() == "numpy": from keras.src.backend.numpy.layer import NumpyLayer as BackendLayer elif backend.backend() == "openvino": - from keras.src.backend.openvino.layer import NumpyLayer as BackendLayer + from keras.src.backend.openvino.layer import OpenvinoLayer as BackendLayer else: raise RuntimeError( f"Backend '{backend.backend()}' must implement a layer mixin class." 
From f57847bc31ceed490b40467dfead48fec7e6de87 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 24 Sep 2024 20:52:43 +0400 Subject: [PATCH 10/19] Remove openvino specific code in common part Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/trainer.py | 163 ++++++++++++-------------- keras/src/models/functional.py | 10 -- keras/src/ops/function.py | 35 ------ keras/src/utils/backend_utils.py | 6 +- 4 files changed, 80 insertions(+), 134 deletions(-) diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py index 5c510d6a958..c6d8575b002 100644 --- a/keras/src/backend/openvino/trainer.py +++ b/keras/src/backend/openvino/trainer.py @@ -1,10 +1,14 @@ import numpy as np +import openvino as ov +import openvino.runtime.opset14 as ov_opset from keras.src import backend from keras.src import callbacks as callbacks_module from keras.src import tree from keras.src.backend.common import standardize_dtype from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.backend.openvino.core import OPENVINO_DTYPES +from keras.src.backend.openvino.core import get_device from keras.src.backend.openvino.core import is_tensor from keras.src.trainers import trainer as base_trainer from keras.src.trainers.data_adapters import data_adapter_utils @@ -19,29 +23,15 @@ def __init__(self): self.predict_function = None def test_step(self, data): - ( - x, - y, - sample_weight, - ) = data_adapter_utils.unpack_x_y_sample_weight(data) - if self._call_has_training_arg: - y_pred = self(x, training=False) - else: - y_pred = self(x) - loss = self.compute_loss( - x=x, y=y, y_pred=y_pred, sample_weight=sample_weight - ) - self._loss_tracker.update_state( - loss, sample_weight=tree.flatten(x)[0].shape[0] - ) + (x, y, sample_weight,) = data_adapter_utils.unpack_x_y_sample_weight(data) + y_pred = self(x) + loss = self.compute_loss(x=x, y=y, y_pred=y_pred, sample_weight=sample_weight) + self._loss_tracker.update_state(loss, sample_weight=tree.flatten(x)[0].shape[0]) return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight) def predict_step(self, data): x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) - if self._call_has_training_arg: - y_pred = self(x, training=False) - else: - y_pred = self(x) + y_pred = self(x) return y_pred def make_test_function(self, force=False): @@ -68,34 +58,28 @@ def make_predict_function(self, force=False): if self.predict_function is not None and not force: return self.predict_function - def one_predict_step(data): - data = data[0] - return self.predict_step(data) - - def multi_predict_steps(data): - outputs = one_predict_step(data[:1]) - - for single_step_data in data[1:]: - step_outputs = one_predict_step([single_step_data]) - outputs = tree.map_structure( - lambda t1, t2: np.concatenate([t1, t2]), - outputs, - step_outputs, - ) - return outputs - - if self.steps_per_execution > 1: - predict_step = multi_predict_steps - else: - predict_step = one_predict_step - - self.predict_function = predict_step + ov_inputs = [] + for _input in self._inputs: + ov_type = OPENVINO_DTYPES[_input.dtype] + ov_shape = _input.shape + ov_shape = list(ov_shape) + for i in range(len(ov_shape)): + if ov_shape[i] is None: + ov_shape[i] = -1 + param = ov_opset.parameter(shape=ov_shape, dtype=ov_type) + ov_inputs.append(param) + # build OpenVINO graph ov.Model + ov_outputs = self._run_through_graph(ov_inputs, operation_fn=lambda op: op) + ov_outputs = tree.flatten(ov_outputs) + ov_model = ov.Model(results=ov_outputs, 
parameters=ov_inputs) + self.predict_function = ov.compile_model(ov_model, get_device()) + return self.predict_function def _symbolic_build(self, data_batch): model_unbuilt = not all(layer.built for layer in self._flatten_layers()) compile_metrics_unbuilt = ( - self._compile_metrics is not None - and not self._compile_metrics.built + self._compile_metrics is not None + and not self._compile_metrics.built ) if model_unbuilt or compile_metrics_unbuilt: # Create symbolic tensors matching an input batch. @@ -136,29 +120,31 @@ def to_symbolic_input(v): self._post_build() def fit( - self, - x=None, - y=None, - batch_size=None, - epochs=1, - verbose="auto", - callbacks=None, - validation_split=0.0, - validation_data=None, - shuffle=True, - class_weight=None, - sample_weight=None, - initial_epoch=0, - steps_per_epoch=None, - validation_steps=None, - validation_batch_size=None, - validation_freq=1, + self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose="auto", + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, ): - raise NotImplementedError("fit not implemented for OpenVINO backend.") + raise NotImplementedError( + "fit not supported with openvino backend." + ) @traceback_utils.filter_traceback def predict( - self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + self, x, batch_size=None, verbose="auto", steps=None, callbacks=None ): # Create an iterator that yields batches of input data. epoch_iterator = EpochIterator( @@ -196,13 +182,20 @@ def append_to_outputs(batch_outputs, outputs): ) return outputs + def unpack_singleton(x): + if isinstance(x, (list, tuple)) and len(x) == 1: + return x[0] + return x + self.make_predict_function() self.stop_predicting = False callbacks.on_predict_begin() outputs = None for step, data in epoch_iterator.enumerate_epoch(): callbacks.on_predict_batch_begin(step) - batch_outputs = self.predict_function(data) + flat_inputs = tree.flatten(data) + batch_outputs = self.predict_function(flat_inputs) + batch_outputs = unpack_singleton(tree.pack_sequence_as(self._outputs_struct, batch_outputs.to_tuple())) outputs = append_to_outputs(batch_outputs, outputs) callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) if self.stop_predicting: @@ -212,16 +205,16 @@ def append_to_outputs(batch_outputs, outputs): @traceback_utils.filter_traceback def evaluate( - self, - x=None, - y=None, - batch_size=None, - verbose="auto", - sample_weight=None, - steps=None, - callbacks=None, - return_dict=False, - **kwargs, + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, ): # TODO: respect compiled trainable state use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) @@ -281,23 +274,23 @@ def evaluate( return self._flatten_metrics_in_order(logs) def train_on_batch( - self, - x, - y=None, - sample_weight=None, - class_weight=None, - return_dict=False, + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, ): raise NotImplementedError( - "train_on_batch not implemented for OpenVINO backend." + "train_on_batch not supported with openvino backend." 
) def test_on_batch( - self, - x, - y=None, - sample_weight=None, - return_dict=False, + self, + x, + y=None, + sample_weight=None, + return_dict=False, ): self._assert_compile_called("test_on_batch") diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index 9d52da1d8ad..f12227a7604 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -170,18 +170,8 @@ def layers(self, _): ) def call(self, inputs, training=None, mask=None): - from keras.src.backend.config import backend # Add support for traning, masking inputs = self._standardize_inputs(inputs) - if backend() == "openvino": - from keras.src.backend.openvino.core import get_device - if self._ov_device != get_device(): - # update the current device and re-compile a model - self._ov_device = get_device() - self._ov_compiled_model = self._ov_core.compile_model(self._ov_model, self._openvino_device) - outputs = self._ov_compiled_model(inputs) - return unpack_singleton(tree.pack_sequence_as(self._outputs_struct, outputs.to_tuple())) - if mask is None: masks = [None] * len(inputs) else: diff --git a/keras/src/ops/function.py b/keras/src/ops/function.py index 0f1cd4c376d..3e5daf035b0 100644 --- a/keras/src/ops/function.py +++ b/keras/src/ops/function.py @@ -81,32 +81,6 @@ def __init__(self, inputs, outputs, name=None): self._nodes_by_depth = nodes_by_depth self._operations = operations self._operations_by_depth = operations_by_depth - if backend() == "openvino": - from keras.src.backend.openvino.core import OPENVINO_DTYPES - from keras.src.backend.openvino.core import get_device - import openvino as ov - import openvino.runtime.opset14 as ov_opset - from openvino import Core - # prepare OpenVINO parameters - ov_inputs = [] - for _input in self._inputs: - ov_type = OPENVINO_DTYPES[_input.dtype] - ov_shape = _input.shape - ov_shape = list(ov_shape) - for i in range(len(ov_shape)): - if ov_shape[i] is None: - ov_shape[i] = -1 - param = ov_opset.parameter(shape=ov_shape, dtype=ov_type) - ov_inputs.append(param) - pass - # build OpenVINO graph - ov.Model - ov_outputs = self._run_through_graph(ov_inputs, operation_fn=lambda op: op) - ov_outputs = tree.flatten(ov_outputs) - self._ov_model = ov.Model(results=ov_outputs, parameters=ov_inputs) - self._ov_core = Core() - self._ov_device = get_device() - self._ov_compiled_model = self._ov_core.compile_model(self._ov_model, self._ov_device) - pass @property def operations(self): @@ -161,15 +135,6 @@ def compute_output_shape(self, input_shape): def call(self, inputs): """Computes output tensors for new inputs.""" self._assert_input_compatibility(inputs) - if backend() == "openvino": - from keras.src.backend.openvino.core import get_device - if self._ov_device != get_device(): - # update the current device and re-compile a model - self._ov_device = get_device() - self._ov_compiled_model = self._ov_core.compile_model(self._ov_model, self._openvino_device) - inputs = tree.flatten(inputs) - outputs = self._ov_compiled_model(inputs) - return tree.pack_sequence_as(self._outputs_struct, outputs.to_tuple()) return self._run_through_graph(inputs, operation_fn=lambda op: op) def _run_through_graph(self, inputs, operation_fn, call_fn=None): diff --git a/keras/src/utils/backend_utils.py b/keras/src/utils/backend_utils.py index a88f5c79999..cd07121bf35 100644 --- a/keras/src/utils/backend_utils.py +++ b/keras/src/utils/backend_utils.py @@ -92,11 +92,9 @@ def __getattr__(self, name): "Currently, we cannot dynamically import the numpy backend " "because it would disrupt 
the namespace of the import." ) - if self._backend == "openvino": - from keras.src.backend import openvino as openvino_backend - - return getattr(openvino_backend, name) + module = importlib.import_module("keras.src.backend.openvino") + return getattr(module, name) @keras_export("keras.config.set_backend") def set_backend(backend): From 8fb2dc5d6e8b98c1cdb11b2fb9d2c8cd3f18eefb Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Tue, 24 Sep 2024 20:54:31 +0400 Subject: [PATCH 11/19] Fix typo --- keras/src/models/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index f12227a7604..70be97697fd 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -170,7 +170,7 @@ def layers(self, _): ) def call(self, inputs, training=None, mask=None): - # Add support for traning, masking + # Add support for training, masking inputs = self._standardize_inputs(inputs) if mask is None: masks = [None] * len(inputs) From 3d8b41a0b656c70647470f83bbad192cca32dc1d Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Wed, 25 Sep 2024 09:44:06 +0400 Subject: [PATCH 12/19] Clean-up code Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/core.py | 24 ++-- keras/src/backend/openvino/numpy.py | 34 ++--- keras/src/backend/openvino/trainer.py | 187 ++++++++------------------ 3 files changed, 82 insertions(+), 163 deletions(-) diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index 6d49aa56b96..3aca7c6806f 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -71,17 +71,25 @@ def _parse_device_input(device_name): class Variable(KerasVariable): def _initialize(self, value): - self._value = np.array(value, dtype=self._dtype) + raise NotImplementedError( + "`Variable._initialize` is not supported with openvino backend" + ) def _direct_assign(self, value): - self._value = np.array(value, dtype=self._dtype) + raise NotImplementedError( + "`Variable._direct_assign` is not supported with openvino backend" + ) def _convert_to_tensor(self, value, dtype=None): - return convert_to_tensor(value, dtype=dtype) + raise NotImplementedError( + "`Variable._convert_to_tensor` is not supported with openvino backend" + ) # Overload native accessor. def __array__(self): - return self.value + raise NotImplementedError( + "`Variable.__array__` is not supported with openvino backend" + ) def convert_to_tensor(x, dtype=None, sparse=None): @@ -94,14 +102,12 @@ def convert_to_tensor(x, dtype=None, sparse=None): return x.value.astype(dtype) return x.value if not is_tensor(x) and standardize_dtype(dtype) == "bfloat16": - # Can't create bfloat16 arrays on the fly (e.g. from a h5 Dataset). - # Instead we convert "as is" (to stored dtype) and cast. 
- return np.asarray(x).astype(dtype) + return ov.Tensor(np.asarray(x).astype(dtype)) if dtype is None: dtype = result_type( *[getattr(item, "dtype", type(item)) for item in tree.flatten(x)] ) - return np.array(x, dtype=dtype) + return ov.Tensor(np.array(x, dtype=dtype)) def convert_to_numpy(x): @@ -109,7 +115,7 @@ def convert_to_numpy(x): def is_tensor(x): - if isinstance(x, (np.generic, np.ndarray)): + if isinstance(x, ov.Tensor): return True return False diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index eabbeb64206..a08a75fd332 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1,3 +1,5 @@ +from openvino.runtime import opset14 + from keras.src.backend.common import dtypes from keras.src.backend.openvino.core import OPENVINO_DTYPES from keras.src.backend.openvino.core import convert_to_tensor @@ -5,7 +7,6 @@ def _align_operand_types(x1, x2, op_name): - from openvino.runtime.opset14 import convert x1_type = x1.element_type x2_type = x2.element_type if x1_type.is_dynamic() or x2_type.is_dynamic(): @@ -17,16 +18,15 @@ def _align_operand_types(x1, x2, op_name): result_type = dtypes.result_type(x1_type, x2_type) result_type = OPENVINO_DTYPES[result_type] if x1_type != result_type: - x1 = convert(x1, result_type) + x1 = opset14.convert(x1, result_type) if x2_type != result_type: - x2 = convert(x2, result_type) + x2 = opset14.convert(x2, result_type) return x1, x2 def add(x1, x2): - from openvino.runtime.opset14 import add x1, x2 = _align_operand_types(x1, x2, "add()") - return add(x1, x2) + return opset14.add(x1, x2) def einsum(subscripts, *operands, **kwargs): @@ -36,9 +36,8 @@ def einsum(subscripts, *operands, **kwargs): def subtract(x1, x2): - from openvino.runtime.opset14 import subtract x1, x2 = _align_operand_types(x1, x2, "subtract()") - return subtract(x1, x2) + return opset14.subtract(x1, x2) def matmul(x1, x2): @@ -48,9 +47,8 @@ def matmul(x1, x2): def multiply(x1, x2): - from openvino.runtime.opset14 import multiply x1, x2 = _align_operand_types(x1, x2, "multiply()") - return multiply(x1, x2) + return opset14.multiply(x1, x2) def mean(x, axis=None, keepdims=False): @@ -78,13 +76,11 @@ def zeros(shape, dtype=None): def absolute(x): - from openvino.runtime.opset14 import absolute - return absolute(x) + return opset14.absolute(x) def abs(x): - from openvino.runtime.opset14 import absolute - return absolute(x) + return opset14.absolute(x) def all(x, axis=None, keepdims=False): @@ -772,9 +768,8 @@ def where(condition, x1, x2): def divide(x1, x2): - from openvino.runtime.opset14 import divide x1, x2 = _align_operand_types(x1, x2) - return divide(x1, x2) + return opset14.divide(x1, x2) def divide_no_nan(x1, x2): @@ -788,14 +783,12 @@ def true_divide(x1, x2): def power(x1, x2): - from openvino.runtime.opset14 import power x1, x2 = _align_operand_types(x1, x2) - return power(x1, x2) + return opset14.power(x1, x2) def negative(x): - from openvino.runtime.opset14 import negative - return negative(x) + return opset14.negative(x) def square(x): @@ -847,8 +840,7 @@ def floor_divide(x1, x2): def logical_xor(x1, x2): - from openvino.runtime.opset14 import logical_xor - return logical_xor(x1, x2) + return opset14.logical_xor(x1, x2) def correlate(x1, x2, mode="valid"): diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py index c6d8575b002..cebf5a0fe0c 100644 --- a/keras/src/backend/openvino/trainer.py +++ b/keras/src/backend/openvino/trainer.py @@ -5,11 +5,8 @@ from keras.src 
import backend from keras.src import callbacks as callbacks_module from keras.src import tree -from keras.src.backend.common import standardize_dtype -from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.openvino.core import OPENVINO_DTYPES from keras.src.backend.openvino.core import get_device -from keras.src.backend.openvino.core import is_tensor from keras.src.trainers import trainer as base_trainer from keras.src.trainers.data_adapters import data_adapter_utils from keras.src.trainers.epoch_iterator import EpochIterator @@ -21,17 +18,25 @@ def __init__(self): super().__init__() self.test_function = None self.predict_function = None + self.ov_compiled_model = None + self.ov_device = None + + def _unpack_singleton(self, x): + if isinstance(x, (list, tuple)) and len(x) == 1: + return x[0] + return x def test_step(self, data): - (x, y, sample_weight,) = data_adapter_utils.unpack_x_y_sample_weight(data) - y_pred = self(x) - loss = self.compute_loss(x=x, y=y, y_pred=y_pred, sample_weight=sample_weight) - self._loss_tracker.update_state(loss, sample_weight=tree.flatten(x)[0].shape[0]) - return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight) + raise NotImplementedError( + "`test_step` is not supported with openvino backend" + ) def predict_step(self, data): x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) - y_pred = self(x) + ov_compiled_model = self._get_compiled_model() + flatten_x = tree.flatten(x) + y_pred = ov_compiled_model(flatten_x) + y_pred = self._unpack_singleton(tree.pack_sequence_as(self._outputs_struct, y_pred.to_tuple())) return y_pred def make_test_function(self, force=False): @@ -54,10 +59,12 @@ def multi_test_steps(data): self.test_function = test_step - def make_predict_function(self, force=False): - if self.predict_function is not None and not force: - return self.predict_function + def _get_compiled_model(self): + if self.ov_compiled_model is not None and get_device() == self.ov_device: + return self.ov_compiled_model + # prepare compiled model from scratch + del self.ov_compiled_model ov_inputs = [] for _input in self._inputs: ov_type = OPENVINO_DTYPES[_input.dtype] @@ -72,52 +79,36 @@ def make_predict_function(self, force=False): ov_outputs = self._run_through_graph(ov_inputs, operation_fn=lambda op: op) ov_outputs = tree.flatten(ov_outputs) ov_model = ov.Model(results=ov_outputs, parameters=ov_inputs) - self.predict_function = ov.compile_model(ov_model, get_device()) - return self.predict_function + self.ov_compiled_model = ov.compile_model(ov_model, get_device()) + self.ov_device = get_device() + return self.ov_compiled_model - def _symbolic_build(self, data_batch): - model_unbuilt = not all(layer.built for layer in self._flatten_layers()) - compile_metrics_unbuilt = ( - self._compile_metrics is not None - and not self._compile_metrics.built - ) - if model_unbuilt or compile_metrics_unbuilt: - # Create symbolic tensors matching an input batch. + def make_predict_function(self, force=False): + if self.predict_function is not None and not force: + return self.predict_function - def to_symbolic_input(v): - if is_tensor(v): - return KerasTensor(v.shape, standardize_dtype(v.dtype)) - return v + def one_predict_step(data): + data = data[0] + return self.predict_step(data) - data_batch = tree.map_structure(to_symbolic_input, data_batch) - ( - x, - y, - sample_weight, - ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch) - # Build all model state with `backend.compute_output_spec`. 
- try: - y_pred = backend.compute_output_spec(self, x) - except: - raise RuntimeError( - "Unable to automatically build the model. " - "Please build it yourself before calling " - "fit/evaluate/predict. " - "A model is 'built' when its variables have " - "been created and its `self.built` attribute " - "is True. Usually, calling the model on a batch " - "of data is the right way to build it." - ) - if compile_metrics_unbuilt: - # Build all metric state with `backend.compute_output_spec`. - backend.compute_output_spec( - self.compute_metrics, - x, - y, - y_pred, - sample_weight=sample_weight, + def multi_predict_steps(data): + outputs = one_predict_step(data[:1]) + + for single_step_data in data[1:]: + step_outputs = one_predict_step([single_step_data]) + outputs = tree.map_structure( + lambda t1, t2: np.concatenate([t1, t2]), + outputs, + step_outputs, ) - self._post_build() + return outputs + + if self.steps_per_execution > 1: + predict_step = multi_predict_steps + else: + predict_step = one_predict_step + + self.predict_function = predict_step def fit( self, @@ -139,7 +130,7 @@ def fit( validation_freq=1, ): raise NotImplementedError( - "fit not supported with openvino backend." + "`fit` is not supported with openvino backend" ) @traceback_utils.filter_traceback @@ -182,20 +173,13 @@ def append_to_outputs(batch_outputs, outputs): ) return outputs - def unpack_singleton(x): - if isinstance(x, (list, tuple)) and len(x) == 1: - return x[0] - return x - self.make_predict_function() self.stop_predicting = False callbacks.on_predict_begin() outputs = None for step, data in epoch_iterator.enumerate_epoch(): callbacks.on_predict_batch_begin(step) - flat_inputs = tree.flatten(data) - batch_outputs = self.predict_function(flat_inputs) - batch_outputs = unpack_singleton(tree.pack_sequence_as(self._outputs_struct, batch_outputs.to_tuple())) + batch_outputs = self.predict_function(data) outputs = append_to_outputs(batch_outputs, outputs) callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) if self.stop_predicting: @@ -216,62 +200,9 @@ def evaluate( return_dict=False, **kwargs, ): - # TODO: respect compiled trainable state - use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) - if kwargs: - raise ValueError(f"Arguments not recognized: {kwargs}") - - if use_cached_eval_dataset: - epoch_iterator = self._eval_epoch_iterator - else: - # Create an iterator that yields batches of input/target data. - epoch_iterator = EpochIterator( - x=x, - y=y, - sample_weight=sample_weight, - batch_size=batch_size, - steps_per_epoch=steps, - shuffle=False, - steps_per_execution=self.steps_per_execution, - ) - - if not all(layer.built for layer in self._flatten_layers()): - # Build the model on one batch of data. - for _, data in epoch_iterator.enumerate_epoch(): - data_batch = data[0] - self._symbolic_build(data_batch) - break - - # Container that configures and calls callbacks. 
- if not isinstance(callbacks, callbacks_module.CallbackList): - callbacks = callbacks_module.CallbackList( - callbacks, - add_history=True, - add_progbar=verbose != 0, - verbose=verbose, - epochs=1, - steps=epoch_iterator.num_batches, - model=self, - ) - - self.make_test_function() - self.stop_evaluating = False - callbacks.on_test_begin() - logs = None - self.reset_metrics() - for step, data in epoch_iterator.enumerate_epoch(): - callbacks.on_test_batch_begin(step) - logs = self.test_function(data) - logs = self._pythonify_logs(logs) - callbacks.on_test_batch_end(step, logs) - if self.stop_evaluating: - break - logs = self._get_metrics_result_or_logs(logs) - callbacks.on_test_end(logs) - - if return_dict: - return logs - return self._flatten_metrics_in_order(logs) + raise NotImplementedError( + "`evaluate` is not supported with openvino backend" + ) def train_on_batch( self, @@ -282,7 +213,7 @@ def train_on_batch( return_dict=False, ): raise NotImplementedError( - "train_on_batch not supported with openvino backend." + "`train_on_batch` is not supported with openvino backend" ) def test_on_batch( @@ -292,19 +223,9 @@ def test_on_batch( sample_weight=None, return_dict=False, ): - self._assert_compile_called("test_on_batch") - - data = (x, y, sample_weight) - - # Maybe build model - self._symbolic_build(data) - self.make_test_function() - - logs = self.test_function([data]) - logs = tree.map_structure(lambda x: np.array(x), logs) - if return_dict: - return logs - return self._flatten_metrics_in_order(logs) + raise NotImplementedError( + "`test_on_batch` is not supported with openvino backend" + ) def predict_on_batch(self, x): self.make_predict_function() From 08b74dc1c6fe4da6dacc05700a03616518524a33 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Wed, 25 Sep 2024 14:57:02 +0400 Subject: [PATCH 13/19] Recover imports Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/__init__.py | 1 + keras/src/backend/openvino/core.py | 8 +++++++- keras/src/backend/openvino/image.py | 4 ++++ keras/src/backend/openvino/linalg.py | 6 ++++++ keras/src/backend/openvino/math.py | 5 +++++ keras/src/backend/openvino/nn.py | 3 +++ keras/src/backend/openvino/numpy.py | 4 ++++ keras/src/backend/openvino/random.py | 3 +++ keras/src/backend/openvino/rnn.py | 3 +++ keras/src/backend/openvino/trainer.py | 2 ++ 10 files changed, 38 insertions(+), 1 deletion(-) diff --git a/keras/src/backend/openvino/__init__.py b/keras/src/backend/openvino/__init__.py index c54a33c809a..8ed2ad1f594 100644 --- a/keras/src/backend/openvino/__init__.py +++ b/keras/src/backend/openvino/__init__.py @@ -13,6 +13,7 @@ from keras.src.backend.openvino.core import convert_to_numpy from keras.src.backend.openvino.core import convert_to_tensor from keras.src.backend.openvino.core import is_tensor +from keras.src.backend.openvino.core import random_seed_dtype from keras.src.backend.openvino.core import shape from keras.src.backend.openvino.core import vectorized_map from keras.src.backend.openvino.rnn import cudnn_ok diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index 3aca7c6806f..8cb3eb358e8 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -4,12 +4,14 @@ import openvino as ov from keras.src import tree -from keras.src.backend.common import KerasVariable from keras.src.backend.common import global_state +from keras.src.backend.common import KerasVariable from keras.src.backend.common import standardize_dtype +from 
keras.src.backend.common.backend_utils import slice_along_axis from keras.src.backend.common.dtypes import result_type from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.symbolic_scope import SymbolicScope SUPPORTS_SPARSE_TENSORS = False @@ -259,6 +261,10 @@ def unstack(x, num=None, axis=0): ) +def random_seed_dtype(): + return "uint32" + + def custom_gradient(fun): raise NotImplementedError( "`custom_gradient` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/image.py b/keras/src/backend/openvino/image.py index 86c59e51231..fa76f89b35c 100644 --- a/keras/src/backend/openvino/image.py +++ b/keras/src/backend/openvino/image.py @@ -1,3 +1,7 @@ +from keras.src import backend +from keras.src.utils.module_utils import scipy + + def rgb_to_grayscale(image, data_format="channels_last"): raise NotImplementedError( "`rgb_to_grayscale` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/linalg.py b/keras/src/backend/openvino/linalg.py index 5c6c65cf86e..1fb1e6400b7 100644 --- a/keras/src/backend/openvino/linalg.py +++ b/keras/src/backend/openvino/linalg.py @@ -1,3 +1,9 @@ +import scipy.linalg as sl + +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes + + def cholesky(a): raise NotImplementedError( "`cholesky` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/math.py b/keras/src/backend/openvino/math.py index 23cf17bb3a6..d9aaefb86e8 100644 --- a/keras/src/backend/openvino/math.py +++ b/keras/src/backend/openvino/math.py @@ -1,3 +1,8 @@ +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.utils.module_utils import scipy + + def segment_sum(data, segment_ids, num_segments=None, sorted=False): raise NotImplementedError( "`segment_sum` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py index 3b75d4476e4..ac9ac7ac31d 100644 --- a/keras/src/backend/openvino/nn.py +++ b/keras/src/backend/openvino/nn.py @@ -1,6 +1,9 @@ import numpy as np from openvino.runtime import opset14 +from keras.src import backend +from keras.src.utils.module_utils import scipy + def relu(x): return opset14.relu(x) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index a08a75fd332..b987e9326f6 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1,6 +1,10 @@ from openvino.runtime import opset14 +from keras.src import tree +from keras.src.backend import config +from keras.src.backend import standardize_dtype from keras.src.backend.common import dtypes +from keras.src.backend.common.backend_utils import standardize_axis_for_numpy from keras.src.backend.openvino.core import OPENVINO_DTYPES from keras.src.backend.openvino.core import convert_to_tensor from keras.src.backend.openvino.core import ov_to_keras_type diff --git a/keras/src/backend/openvino/random.py b/keras/src/backend/openvino/random.py index ed902e34ad1..c8a4762b24e 100644 --- a/keras/src/backend/openvino/random.py +++ b/keras/src/backend/openvino/random.py @@ -1,4 +1,7 @@ +from keras.src.backend.config import floatx from keras.src.random.seed_generator import SeedGenerator +from keras.src.random.seed_generator import draw_seed +from keras.src.random.seed_generator import make_default_seed def normal(shape, mean=0.0, stddev=1.0, 
dtype=None, seed=None): diff --git a/keras/src/backend/openvino/rnn.py b/keras/src/backend/openvino/rnn.py index 41fba1ffc08..d1f9abc6761 100644 --- a/keras/src/backend/openvino/rnn.py +++ b/keras/src/backend/openvino/rnn.py @@ -1,3 +1,6 @@ +from keras.src import tree + + def rnn( step_function, inputs, diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py index cebf5a0fe0c..c06988b7ff8 100644 --- a/keras/src/backend/openvino/trainer.py +++ b/keras/src/backend/openvino/trainer.py @@ -5,6 +5,8 @@ from keras.src import backend from keras.src import callbacks as callbacks_module from keras.src import tree +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.openvino.core import OPENVINO_DTYPES from keras.src.backend.openvino.core import get_device from keras.src.trainers import trainer as base_trainer From 52408b70d5ac1f04c1f89b239e3a1ec9751b7200 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Wed, 25 Sep 2024 19:40:00 +0400 Subject: [PATCH 14/19] Sort imports properly Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index 8cb3eb358e8..ca4ceb7217f 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -4,8 +4,8 @@ import openvino as ov from keras.src import tree -from keras.src.backend.common import global_state from keras.src.backend.common import KerasVariable +from keras.src.backend.common import global_state from keras.src.backend.common import standardize_dtype from keras.src.backend.common.backend_utils import slice_along_axis from keras.src.backend.common.dtypes import result_type From f41b4543f55b21cebbd31b0b9b736820e95bfcef Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Wed, 25 Sep 2024 21:29:06 +0400 Subject: [PATCH 15/19] Format source code Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/core.py | 24 +- keras/src/backend/openvino/image.py | 36 ++- keras/src/backend/openvino/linalg.py | 36 +-- keras/src/backend/openvino/math.py | 68 ++---- keras/src/backend/openvino/numpy.py | 326 +++++++------------------- keras/src/backend/openvino/random.py | 12 +- keras/src/backend/openvino/rnn.py | 34 ++- keras/src/backend/openvino/trainer.py | 91 +++---- keras/src/utils/backend_utils.py | 1 + 9 files changed, 204 insertions(+), 424 deletions(-) diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index ca4ceb7217f..628c7049229 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -127,15 +127,11 @@ def shape(x): def cast(x, dtype): - raise NotImplementedError( - "`cast` is not supported with openvino backend" - ) + raise NotImplementedError("`cast` is not supported with openvino backend") def cond(pred, true_fn, false_fn): - raise NotImplementedError( - "`cond` is not supported with openvino backend" - ) + raise NotImplementedError("`cond` is not supported with openvino backend") def vectorized_map(function, elements): @@ -205,9 +201,7 @@ def convert_numpy_to_keras_tensor(x): def scan(f, init, xs=None, length=None, reverse=False, unroll=1): - raise NotImplementedError( - "`scan` is not supported with openvino backend" - ) + raise NotImplementedError("`scan` is not supported with openvino backend") def scatter(indices, values, shape): @@ -223,9 +217,7 @@ def 
scatter_update(inputs, indices, updates): def slice(inputs, start_indices, lengths): - raise NotImplementedError( - "`slice` is not supported with openvino backend" - ) + raise NotImplementedError("`slice` is not supported with openvino backend") def slice_update(inputs, start_indices, updates): @@ -235,10 +227,10 @@ def slice_update(inputs, start_indices, updates): def while_loop( - cond, - body, - loop_vars, - maximum_iterations=None, + cond, + body, + loop_vars, + maximum_iterations=None, ): raise NotImplementedError( "`while_loop` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/image.py b/keras/src/backend/openvino/image.py index fa76f89b35c..0591382f1f6 100644 --- a/keras/src/backend/openvino/image.py +++ b/keras/src/backend/openvino/image.py @@ -9,28 +9,26 @@ def rgb_to_grayscale(image, data_format="channels_last"): def resize( - image, - size, - interpolation="bilinear", - antialias=False, - crop_to_aspect_ratio=False, - pad_to_aspect_ratio=False, - fill_mode="constant", - fill_value=0.0, - data_format="channels_last", + image, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format="channels_last", ): - raise NotImplementedError( - "`resize` is not supported with openvino backend" - ) + raise NotImplementedError("`resize` is not supported with openvino backend") def affine_transform( - image, - transform, - interpolation="bilinear", - fill_mode="constant", - fill_value=0, - data_format="channels_last", + image, + transform, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + data_format="channels_last", ): raise NotImplementedError( "`affine_transform` is not supported with openvino backend" @@ -38,7 +36,7 @@ def affine_transform( def map_coordinates( - input, coordinates, order, fill_mode="constant", fill_value=0.0 + input, coordinates, order, fill_mode="constant", fill_value=0.0 ): raise NotImplementedError( "`map_coordinates` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/linalg.py b/keras/src/backend/openvino/linalg.py index 1fb1e6400b7..49edc4b2310 100644 --- a/keras/src/backend/openvino/linalg.py +++ b/keras/src/backend/openvino/linalg.py @@ -11,27 +11,19 @@ def cholesky(a): def det(a): - raise NotImplementedError( - "`det` is not supported with openvino backend" - ) + raise NotImplementedError("`det` is not supported with openvino backend") def eig(a): - raise NotImplementedError( - "`eig` is not supported with openvino backend" - ) + raise NotImplementedError("`eig` is not supported with openvino backend") def eigh(a): - raise NotImplementedError( - "`eigh` is not supported with openvino backend" - ) + raise NotImplementedError("`eigh` is not supported with openvino backend") def inv(a): - raise NotImplementedError( - "`inv` is not supported with openvino backend" - ) + raise NotImplementedError("`inv` is not supported with openvino backend") def lu_factor(a): @@ -41,21 +33,15 @@ def lu_factor(a): def norm(x, ord=None, axis=None, keepdims=False): - raise NotImplementedError( - "`norm` is not supported with openvino backend" - ) + raise NotImplementedError("`norm` is not supported with openvino backend") def qr(x, mode="reduced"): - raise NotImplementedError( - "`qr` is not supported with openvino backend" - ) + raise NotImplementedError("`qr` is not supported with openvino backend") def solve(a, b): - raise NotImplementedError( - "`solve` is not supported with openvino backend" - 
) + raise NotImplementedError("`solve` is not supported with openvino backend") def solve_triangular(a, b, lower=False): @@ -65,12 +51,8 @@ def solve_triangular(a, b, lower=False): def svd(x, full_matrices=True, compute_uv=True): - raise NotImplementedError( - "`svd` is not supported with openvino backend" - ) + raise NotImplementedError("`svd` is not supported with openvino backend") def lstsq(a, b, rcond=None): - raise NotImplementedError( - "`lstsq` is not supported with openvino backend" - ) + raise NotImplementedError("`lstsq` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/math.py b/keras/src/backend/openvino/math.py index d9aaefb86e8..a149d09bd56 100644 --- a/keras/src/backend/openvino/math.py +++ b/keras/src/backend/openvino/math.py @@ -16,9 +16,7 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False): def top_k(x, k, sorted=False): - raise NotImplementedError( - "`top_k` is not supported with openvino backend" - ) + raise NotImplementedError("`top_k` is not supported with openvino backend") def in_top_k(targets, predictions, k): @@ -34,9 +32,7 @@ def logsumexp(x, axis=None, keepdims=False): def qr(x, mode="reduced"): - raise NotImplementedError( - "`qr` is not supported with openvino backend" - ) + raise NotImplementedError("`qr` is not supported with openvino backend") def extract_sequences(x, sequence_length, sequence_stride): @@ -46,76 +42,54 @@ def extract_sequences(x, sequence_length, sequence_stride): def fft(x): - raise NotImplementedError( - "`fft` is not supported with openvino backend" - ) + raise NotImplementedError("`fft` is not supported with openvino backend") def fft2(x): - raise NotImplementedError( - "`fft2` is not supported with openvino backend" - ) + raise NotImplementedError("`fft2` is not supported with openvino backend") def rfft(x, fft_length=None): - raise NotImplementedError( - "`rfft` is not supported with openvino backend" - ) + raise NotImplementedError("`rfft` is not supported with openvino backend") def irfft(x, fft_length=None): - raise NotImplementedError( - "`irfft` is not supported with openvino backend" - ) + raise NotImplementedError("`irfft` is not supported with openvino backend") def stft( - x, sequence_length, sequence_stride, fft_length, window="hann", center=True + x, sequence_length, sequence_stride, fft_length, window="hann", center=True ): - raise NotImplementedError( - "`stft` is not supported with openvino backend" - ) + raise NotImplementedError("`stft` is not supported with openvino backend") def istft( - x, - sequence_length, - sequence_stride, - fft_length, - length=None, - window="hann", - center=True, + x, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, ): - raise NotImplementedError( - "`istft` is not supported with openvino backend" - ) + raise NotImplementedError("`istft` is not supported with openvino backend") def rsqrt(x): - raise NotImplementedError( - "`rsqrt` is not supported with openvino backend" - ) + raise NotImplementedError("`rsqrt` is not supported with openvino backend") def erf(x): - raise NotImplementedError( - "`erf` is not supported with openvino backend" - ) + raise NotImplementedError("`erf` is not supported with openvino backend") def erfinv(x): - raise NotImplementedError( - "`erfinv` is not supported with openvino backend" - ) + raise NotImplementedError("`erfinv` is not supported with openvino backend") def solve(a, b): - raise NotImplementedError( - "`solve` is not supported with openvino backend" - ) 
+ raise NotImplementedError("`solve` is not supported with openvino backend") def norm(x, ord=None, axis=None, keepdims=False): - raise NotImplementedError( - "`norm` is not supported with openvino backend" - ) + raise NotImplementedError("`norm` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index b987e9326f6..b2fe10bf20e 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -34,9 +34,7 @@ def add(x1, x2): def einsum(subscripts, *operands, **kwargs): - raise NotImplementedError( - "`einsum` is not supported with openvino backend" - ) + raise NotImplementedError("`einsum` is not supported with openvino backend") def subtract(x1, x2): @@ -45,9 +43,7 @@ def subtract(x1, x2): def matmul(x1, x2): - raise NotImplementedError( - "`matmul` is not supported with openvino backend" - ) + raise NotImplementedError("`matmul` is not supported with openvino backend") def multiply(x1, x2): @@ -56,27 +52,19 @@ def multiply(x1, x2): def mean(x, axis=None, keepdims=False): - raise NotImplementedError( - "`mean` is not supported with openvino backend" - ) + raise NotImplementedError("`mean` is not supported with openvino backend") def max(x, axis=None, keepdims=False, initial=None): - raise NotImplementedError( - "`max` is not supported with openvino backend" - ) + raise NotImplementedError("`max` is not supported with openvino backend") def ones(shape, dtype=None): - raise NotImplementedError( - "`ones` is not supported with openvino backend" - ) + raise NotImplementedError("`ones` is not supported with openvino backend") def zeros(shape, dtype=None): - raise NotImplementedError( - "`zeros` is not supported with openvino backend" - ) + raise NotImplementedError("`zeros` is not supported with openvino backend") def absolute(x): @@ -88,45 +76,31 @@ def abs(x): def all(x, axis=None, keepdims=False): - raise NotImplementedError( - "`all` is not supported with openvino backend" - ) + raise NotImplementedError("`all` is not supported with openvino backend") def any(x, axis=None, keepdims=False): - raise NotImplementedError( - "`any` is not supported with openvino backend" - ) + raise NotImplementedError("`any` is not supported with openvino backend") def amax(x, axis=None, keepdims=False): - raise NotImplementedError( - "`amax` is not supported with openvino backend" - ) + raise NotImplementedError("`amax` is not supported with openvino backend") def amin(x, axis=None, keepdims=False): - raise NotImplementedError( - "`amin` is not supported with openvino backend" - ) + raise NotImplementedError("`amin` is not supported with openvino backend") def append(x1, x2, axis=None): - raise NotImplementedError( - "`append` is not supported with openvino backend" - ) + raise NotImplementedError("`append` is not supported with openvino backend") def arange(start, stop=None, step=None, dtype=None): - raise NotImplementedError( - "`arange` is not supported with openvino backend" - ) + raise NotImplementedError("`arange` is not supported with openvino backend") def arccos(x): - raise NotImplementedError( - "`arccos` is not supported with openvino backend" - ) + raise NotImplementedError("`arccos` is not supported with openvino backend") def arccosh(x): @@ -136,9 +110,7 @@ def arccosh(x): def arcsin(x): - raise NotImplementedError( - "`arcsin` is not supported with openvino backend" - ) + raise NotImplementedError("`arcsin` is not supported with openvino backend") def arcsinh(x): @@ -148,9 +120,7 @@ def 
arcsinh(x): def arctan(x): - raise NotImplementedError( - "`arctan` is not supported with openvino backend" - ) + raise NotImplementedError("`arctan` is not supported with openvino backend") def arctan2(x1, x2): @@ -166,15 +136,11 @@ def arctanh(x): def argmax(x, axis=None, keepdims=False): - raise NotImplementedError( - "`argmax` is not supported with openvino backend" - ) + raise NotImplementedError("`argmax` is not supported with openvino backend") def argmin(x, axis=None, keepdims=False): - raise NotImplementedError( - "`argmin` is not supported with openvino backend" - ) + raise NotImplementedError("`argmin` is not supported with openvino backend") def argsort(x, axis=-1): @@ -206,15 +172,11 @@ def broadcast_to(x, shape): def ceil(x): - raise NotImplementedError( - "`ceil` is not supported with openvino backend" - ) + raise NotImplementedError("`ceil` is not supported with openvino backend") def clip(x, x_min, x_max): - raise NotImplementedError( - "`clip` is not supported with openvino backend" - ) + raise NotImplementedError("`clip` is not supported with openvino backend") def concatenate(xs, axis=0): @@ -230,27 +192,19 @@ def conjugate(x): def conj(x): - raise NotImplementedError( - "`conj` is not supported with openvino backend" - ) + raise NotImplementedError("`conj` is not supported with openvino backend") def copy(x): - raise NotImplementedError( - "`copy` is not supported with openvino backend" - ) + raise NotImplementedError("`copy` is not supported with openvino backend") def cos(x): - raise NotImplementedError( - "`cos` is not supported with openvino backend" - ) + raise NotImplementedError("`cos` is not supported with openvino backend") def cosh(x): - raise NotImplementedError( - "`cosh` is not supported with openvino backend" - ) + raise NotImplementedError("`cosh` is not supported with openvino backend") def count_nonzero(x, axis=None): @@ -260,9 +214,7 @@ def count_nonzero(x, axis=None): def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): - raise NotImplementedError( - "`cross` is not supported with openvino backend" - ) + raise NotImplementedError("`cross` is not supported with openvino backend") def cumprod(x, axis=None, dtype=None): @@ -272,15 +224,11 @@ def cumprod(x, axis=None, dtype=None): def cumsum(x, axis=None, dtype=None): - raise NotImplementedError( - "`cumsum` is not supported with openvino backend" - ) + raise NotImplementedError("`cumsum` is not supported with openvino backend") def diag(x, k=0): - raise NotImplementedError( - "`diag` is not supported with openvino backend" - ) + raise NotImplementedError("`diag` is not supported with openvino backend") def diagonal(x, offset=0, axis1=0, axis2=1): @@ -290,9 +238,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1): def diff(a, n=1, axis=-1): - raise NotImplementedError( - "`diff` is not supported with openvino backend" - ) + raise NotImplementedError("`diff` is not supported with openvino backend") def digitize(x, bins): @@ -302,27 +248,19 @@ def digitize(x, bins): def dot(x, y): - raise NotImplementedError( - "`dot` is not supported with openvino backend" - ) + raise NotImplementedError("`dot` is not supported with openvino backend") def empty(shape, dtype=None): - raise NotImplementedError( - "`empty` is not supported with openvino backend" - ) + raise NotImplementedError("`empty` is not supported with openvino backend") def equal(x1, x2): - raise NotImplementedError( - "`equal` is not supported with openvino backend" - ) + raise NotImplementedError("`equal` is not supported with openvino backend") 
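# Note: a stub such as `equal` above would, once implemented, follow the
# pattern this file already establishes for add()/subtract()/multiply():
# align the operand element types, then call the matching opset op. A
# minimal sketch only (not part of this patch; assumes x1/x2 are OpenVINO
# graph nodes and relies on opset14.equal, analogous to the ops used above):
#
#     def equal(x1, x2):
#         x1, x2 = _align_operand_types(x1, x2, "equal()")
#         return opset14.equal(x1, x2)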
def exp(x): - raise NotImplementedError( - "`exp` is not supported with openvino backend" - ) + raise NotImplementedError("`exp` is not supported with openvino backend") def expand_dims(x, axis): @@ -332,27 +270,19 @@ def expand_dims(x, axis): def expm1(x): - raise NotImplementedError( - "`expm1` is not supported with openvino backend" - ) + raise NotImplementedError("`expm1` is not supported with openvino backend") def flip(x, axis=None): - raise NotImplementedError( - "`flip` is not supported with openvino backend" - ) + raise NotImplementedError("`flip` is not supported with openvino backend") def floor(x): - raise NotImplementedError( - "`floor` is not supported with openvino backend" - ) + raise NotImplementedError("`floor` is not supported with openvino backend") def full(shape, fill_value, dtype=None): - raise NotImplementedError( - "`full` is not supported with openvino backend" - ) + raise NotImplementedError("`full` is not supported with openvino backend") def full_like(x, fill_value, dtype=None): @@ -374,9 +304,7 @@ def greater_equal(x1, x2): def hstack(xs): - raise NotImplementedError( - "`hstack` is not supported with openvino backend" - ) + raise NotImplementedError("`hstack` is not supported with openvino backend") def identity(n, dtype=None): @@ -386,9 +314,7 @@ def identity(n, dtype=None): def imag(x): - raise NotImplementedError( - "`imag` is not supported with openvino backend" - ) + raise NotImplementedError("`imag` is not supported with openvino backend") def isclose(x1, x2): @@ -404,21 +330,15 @@ def isfinite(x): def isinf(x): - raise NotImplementedError( - "`isinf` is not supported with openvino backend" - ) + raise NotImplementedError("`isinf` is not supported with openvino backend") def isnan(x): - raise NotImplementedError( - "`isnan` is not supported with openvino backend" - ) + raise NotImplementedError("`isnan` is not supported with openvino backend") def less(x1, x2): - raise NotImplementedError( - "`less` is not supported with openvino backend" - ) + raise NotImplementedError("`less` is not supported with openvino backend") def less_equal(x1, x2): @@ -428,7 +348,7 @@ def less_equal(x1, x2): def linspace( - start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 + start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 ): raise NotImplementedError( "`linspace` is not supported with openvino backend" @@ -436,27 +356,19 @@ def linspace( def log(x): - raise NotImplementedError( - "`log` is not supported with openvino backend" - ) + raise NotImplementedError("`log` is not supported with openvino backend") def log10(x): - raise NotImplementedError( - "`log10` is not supported with openvino backend" - ) + raise NotImplementedError("`log10` is not supported with openvino backend") def log1p(x): - raise NotImplementedError( - "`log1p` is not supported with openvino backend" - ) + raise NotImplementedError("`log1p` is not supported with openvino backend") def log2(x): - raise NotImplementedError( - "`log2` is not supported with openvino backend" - ) + raise NotImplementedError("`log2` is not supported with openvino backend") def logaddexp(x1, x2): @@ -496,9 +408,7 @@ def maximum(x1, x2): def median(x, axis=None, keepdims=False): - raise NotImplementedError( - "`median` is not supported with openvino backend" - ) + raise NotImplementedError("`median` is not supported with openvino backend") def meshgrid(*x, indexing="xy"): @@ -508,9 +418,7 @@ def meshgrid(*x, indexing="xy"): def min(x, axis=None, keepdims=False, initial=None): - raise 
NotImplementedError( - "`min` is not supported with openvino backend" - ) + raise NotImplementedError("`min` is not supported with openvino backend") def minimum(x1, x2): @@ -520,9 +428,7 @@ def minimum(x1, x2): def mod(x1, x2): - raise NotImplementedError( - "`mod` is not supported with openvino backend" - ) + raise NotImplementedError("`mod` is not supported with openvino backend") def moveaxis(x, source, destination): @@ -538,9 +444,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None): def ndim(x): - raise NotImplementedError( - "`ndim` is not supported with openvino backend" - ) + raise NotImplementedError("`ndim` is not supported with openvino backend") def nonzero(x): @@ -568,21 +472,15 @@ def ones_like(x, dtype=None): def outer(x1, x2): - raise NotImplementedError( - "`outer` is not supported with openvino backend" - ) + raise NotImplementedError("`outer` is not supported with openvino backend") def pad(x, pad_width, mode="constant", constant_values=None): - raise NotImplementedError( - "`pad` is not supported with openvino backend" - ) + raise NotImplementedError("`pad` is not supported with openvino backend") def prod(x, axis=None, keepdims=False, dtype=None): - raise NotImplementedError( - "`prod` is not supported with openvino backend" - ) + raise NotImplementedError("`prod` is not supported with openvino backend") def quantile(x, q, axis=None, method="linear", keepdims=False): @@ -592,15 +490,11 @@ def quantile(x, q, axis=None, method="linear", keepdims=False): def ravel(x): - raise NotImplementedError( - "`ravel` is not supported with openvino backend" - ) + raise NotImplementedError("`ravel` is not supported with openvino backend") def real(x): - raise NotImplementedError( - "`real` is not supported with openvino backend" - ) + raise NotImplementedError("`real` is not supported with openvino backend") def reciprocal(x): @@ -610,9 +504,7 @@ def reciprocal(x): def repeat(x, repeats, axis=None): - raise NotImplementedError( - "`repeat` is not supported with openvino backend" - ) + raise NotImplementedError("`repeat` is not supported with openvino backend") def reshape(x, newshape): @@ -622,57 +514,39 @@ def reshape(x, newshape): def roll(x, shift, axis=None): - raise NotImplementedError( - "`roll` is not supported with openvino backend" - ) + raise NotImplementedError("`roll` is not supported with openvino backend") def sign(x): - raise NotImplementedError( - "`sign` is not supported with openvino backend" - ) + raise NotImplementedError("`sign` is not supported with openvino backend") def sin(x): - raise NotImplementedError( - "`sin` is not supported with openvino backend" - ) + raise NotImplementedError("`sin` is not supported with openvino backend") def sinh(x): - raise NotImplementedError( - "`sinh` is not supported with openvino backend" - ) + raise NotImplementedError("`sinh` is not supported with openvino backend") def size(x): - raise NotImplementedError( - "`size` is not supported with openvino backend" - ) + raise NotImplementedError("`size` is not supported with openvino backend") def sort(x, axis=-1): - raise NotImplementedError( - "`sort` is not supported with openvino backend" - ) + raise NotImplementedError("`sort` is not supported with openvino backend") def split(x, indices_or_sections, axis=0): - raise NotImplementedError( - "`split` is not supported with openvino backend" - ) + raise NotImplementedError("`split` is not supported with openvino backend") def stack(x, axis=0): - raise NotImplementedError( - "`stack` is not supported with openvino backend" - 
) + raise NotImplementedError("`stack` is not supported with openvino backend") def std(x, axis=None, keepdims=False): - raise NotImplementedError( - "`std` is not supported with openvino backend" - ) + raise NotImplementedError("`std` is not supported with openvino backend") def swapaxes(x, axis1, axis2): @@ -682,9 +556,7 @@ def swapaxes(x, axis1, axis2): def take(x, indices, axis=None): - raise NotImplementedError( - "`take` is not supported with openvino backend" - ) + raise NotImplementedError("`take` is not supported with openvino backend") def take_along_axis(x, indices, axis=None): @@ -694,15 +566,11 @@ def take_along_axis(x, indices, axis=None): def tan(x): - raise NotImplementedError( - "`tan` is not supported with openvino backend" - ) + raise NotImplementedError("`tan` is not supported with openvino backend") def tanh(x): - raise NotImplementedError( - "`tanh` is not supported with openvino backend" - ) + raise NotImplementedError("`tanh` is not supported with openvino backend") def tensordot(x1, x2, axes=2): @@ -712,51 +580,35 @@ def tensordot(x1, x2, axes=2): def round(x, decimals=0): - raise NotImplementedError( - "`round` is not supported with openvino backend" - ) + raise NotImplementedError("`round` is not supported with openvino backend") def tile(x, repeats): - raise NotImplementedError( - "`tile` is not supported with openvino backend" - ) + raise NotImplementedError("`tile` is not supported with openvino backend") def trace(x, offset=0, axis1=0, axis2=1): - raise NotImplementedError( - "`trace` is not supported with openvino backend" - ) + raise NotImplementedError("`trace` is not supported with openvino backend") def tri(N, M=None, k=0, dtype=None): - raise NotImplementedError( - "`tri` is not supported with openvino backend" - ) + raise NotImplementedError("`tri` is not supported with openvino backend") def tril(x, k=0): - raise NotImplementedError( - "`tril` is not supported with openvino backend" - ) + raise NotImplementedError("`tril` is not supported with openvino backend") def triu(x, k=0): - raise NotImplementedError( - "`triu` is not supported with openvino backend" - ) + raise NotImplementedError("`triu` is not supported with openvino backend") def vdot(x1, x2): - raise NotImplementedError( - "`vdot` is not supported with openvino backend" - ) + raise NotImplementedError("`vdot` is not supported with openvino backend") def vstack(xs): - raise NotImplementedError( - "`vstack` is not supported with openvino backend" - ) + raise NotImplementedError("`vstack` is not supported with openvino backend") def vectorize(pyfunc, *, excluded=None, signature=None): @@ -766,9 +618,7 @@ def vectorize(pyfunc, *, excluded=None, signature=None): def where(condition, x1, x2): - raise NotImplementedError( - "`where` is not supported with openvino backend" - ) + raise NotImplementedError("`where` is not supported with openvino backend") def divide(x1, x2): @@ -796,15 +646,11 @@ def negative(x): def square(x): - raise NotImplementedError( - "`square` is not supported with openvino backend" - ) + raise NotImplementedError("`square` is not supported with openvino backend") def sqrt(x): - raise NotImplementedError( - "`sqrt` is not supported with openvino backend" - ) + raise NotImplementedError("`sqrt` is not supported with openvino backend") def squeeze(x, axis=None): @@ -820,21 +666,15 @@ def transpose(x, axes=None): def var(x, axis=None, keepdims=False): - raise NotImplementedError( - "`var` is not supported with openvino backend" - ) + raise NotImplementedError("`var` is not 
supported with openvino backend") def sum(x, axis=None, keepdims=False): - raise NotImplementedError( - "`sum` is not supported with openvino backend" - ) + raise NotImplementedError("`sum` is not supported with openvino backend") def eye(N, M=None, k=0, dtype=None): - raise NotImplementedError( - "`eye` is not supported with openvino backend" - ) + raise NotImplementedError("`eye` is not supported with openvino backend") def floor_divide(x1, x2): @@ -854,9 +694,7 @@ def correlate(x1, x2, mode="valid"): def select(condlist, choicelist, default=0): - raise NotImplementedError( - "`select` is not supported with openvino backend" - ) + raise NotImplementedError("`select` is not supported with openvino backend") def slogdet(x): diff --git a/keras/src/backend/openvino/random.py b/keras/src/backend/openvino/random.py index c8a4762b24e..4d6565d6f58 100644 --- a/keras/src/backend/openvino/random.py +++ b/keras/src/backend/openvino/random.py @@ -5,9 +5,7 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): - raise NotImplementedError( - "`normal` is not supported with openvino backend" - ) + raise NotImplementedError("`normal` is not supported with openvino backend") def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): @@ -47,9 +45,7 @@ def shuffle(x, axis=0, seed=None): def gamma(shape, alpha, dtype=None, seed=None): - raise NotImplementedError( - "`gamma` is not supported with openvino backend" - ) + raise NotImplementedError("`gamma` is not supported with openvino backend") def binomial(shape, counts, probabilities, dtype=None, seed=None): @@ -59,6 +55,4 @@ def binomial(shape, counts, probabilities, dtype=None, seed=None): def beta(shape, alpha, beta, dtype=None, seed=None): - raise NotImplementedError( - "`beta` is not supported with openvino backend" - ) + raise NotImplementedError("`beta` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/rnn.py b/keras/src/backend/openvino/rnn.py index d1f9abc6761..a1421f87259 100644 --- a/keras/src/backend/openvino/rnn.py +++ b/keras/src/backend/openvino/rnn.py @@ -2,33 +2,27 @@ def rnn( - step_function, - inputs, - initial_states, - go_backwards=False, - mask=None, - constants=None, - unroll=False, - input_length=None, - time_major=False, - zero_output_for_mask=False, - return_all_outputs=True, + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, ): - raise NotImplementedError( - "`rnn` is not supported with openvino backend" - ) + raise NotImplementedError("`rnn` is not supported with openvino backend") def lstm(*args, **kwargs): - raise NotImplementedError( - "`lstm` is not supported with openvino backend" - ) + raise NotImplementedError("`lstm` is not supported with openvino backend") def gru(*args, **kwargs): - raise NotImplementedError( - "`gru` is not supported with openvino backend" - ) + raise NotImplementedError("`gru` is not supported with openvino backend") def unstack(x, axis=0): diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py index c06988b7ff8..b2c495c9dd5 100644 --- a/keras/src/backend/openvino/trainer.py +++ b/keras/src/backend/openvino/trainer.py @@ -38,7 +38,9 @@ def predict_step(self, data): ov_compiled_model = self._get_compiled_model() flatten_x = tree.flatten(x) y_pred = ov_compiled_model(flatten_x) - y_pred = self._unpack_singleton(tree.pack_sequence_as(self._outputs_struct, 
y_pred.to_tuple())) + y_pred = self._unpack_singleton( + tree.pack_sequence_as(self._outputs_struct, y_pred.to_tuple()) + ) return y_pred def make_test_function(self, force=False): @@ -62,7 +64,10 @@ def multi_test_steps(data): self.test_function = test_step def _get_compiled_model(self): - if self.ov_compiled_model is not None and get_device() == self.ov_device: + if ( + self.ov_compiled_model is not None + and get_device() == self.ov_device + ): return self.ov_compiled_model # prepare compiled model from scratch @@ -78,7 +83,9 @@ def _get_compiled_model(self): param = ov_opset.parameter(shape=ov_shape, dtype=ov_type) ov_inputs.append(param) # build OpenVINO graph ov.Model - ov_outputs = self._run_through_graph(ov_inputs, operation_fn=lambda op: op) + ov_outputs = self._run_through_graph( + ov_inputs, operation_fn=lambda op: op + ) ov_outputs = tree.flatten(ov_outputs) ov_model = ov.Model(results=ov_outputs, parameters=ov_inputs) self.ov_compiled_model = ov.compile_model(ov_model, get_device()) @@ -113,23 +120,23 @@ def multi_predict_steps(data): self.predict_function = predict_step def fit( - self, - x=None, - y=None, - batch_size=None, - epochs=1, - verbose="auto", - callbacks=None, - validation_split=0.0, - validation_data=None, - shuffle=True, - class_weight=None, - sample_weight=None, - initial_epoch=0, - steps_per_epoch=None, - validation_steps=None, - validation_batch_size=None, - validation_freq=1, + self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose="auto", + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, ): raise NotImplementedError( "`fit` is not supported with openvino backend" @@ -137,7 +144,7 @@ def fit( @traceback_utils.filter_traceback def predict( - self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + self, x, batch_size=None, verbose="auto", steps=None, callbacks=None ): # Create an iterator that yields batches of input data. 
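# Depending on self.steps_per_execution, predict_function expects either a
# single batch or a list of batches per step (see one_predict_step and
# multi_predict_steps in make_predict_function above, which concatenate
# per-step outputs with np.concatenate).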
epoch_iterator = EpochIterator( @@ -191,39 +198,39 @@ def append_to_outputs(batch_outputs, outputs): @traceback_utils.filter_traceback def evaluate( - self, - x=None, - y=None, - batch_size=None, - verbose="auto", - sample_weight=None, - steps=None, - callbacks=None, - return_dict=False, - **kwargs, + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, ): raise NotImplementedError( "`evaluate` is not supported with openvino backend" ) def train_on_batch( - self, - x, - y=None, - sample_weight=None, - class_weight=None, - return_dict=False, + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, ): raise NotImplementedError( "`train_on_batch` is not supported with openvino backend" ) def test_on_batch( - self, - x, - y=None, - sample_weight=None, - return_dict=False, + self, + x, + y=None, + sample_weight=None, + return_dict=False, ): raise NotImplementedError( "`test_on_batch` is not supported with openvino backend" diff --git a/keras/src/utils/backend_utils.py b/keras/src/utils/backend_utils.py index cd07121bf35..05e637d444b 100644 --- a/keras/src/utils/backend_utils.py +++ b/keras/src/utils/backend_utils.py @@ -96,6 +96,7 @@ def __getattr__(self, name): module = importlib.import_module("keras.src.backend.openvino") return getattr(module, name) + @keras_export("keras.config.set_backend") def set_backend(backend): """Reload the backend (and the Keras package). From 42364c252594900d33b1e9e41560de2c4b60f91e Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Wed, 25 Sep 2024 21:32:19 +0400 Subject: [PATCH 16/19] Format the rest of source code Signed-off-by: Kazantsev, Roman --- keras/src/backend/openvino/nn.py | 106 +++++++++++++++---------------- 1 file changed, 50 insertions(+), 56 deletions(-) diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py index ac9ac7ac31d..282f6ade9d7 100644 --- a/keras/src/backend/openvino/nn.py +++ b/keras/src/backend/openvino/nn.py @@ -10,9 +10,7 @@ def relu(x): def relu6(x): - raise NotImplementedError( - "`relu6` is not supported with openvino backend" - ) + raise NotImplementedError("`relu6` is not supported with openvino backend") def sigmoid(x): @@ -62,9 +60,9 @@ def elu(x, alpha=1.0): def selu( - x, - alpha=1.6732632423543772848170429916717, - scale=1.0507009873554804934193349852946, + x, + alpha=1.6732632423543772848170429916717, + scale=1.0507009873554804934193349852946, ): return opset14.selu(x, alpha, scale) @@ -85,11 +83,11 @@ def log_softmax(x, axis=None): def max_pool( - inputs, - pool_size, - strides=None, - padding="valid", - data_format=None, + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, ): raise NotImplementedError( "`max_pool` is not supported with openvino backend" @@ -97,11 +95,11 @@ def max_pool( def average_pool( - inputs, - pool_size, - strides, - padding, - data_format=None, + inputs, + pool_size, + strides, + padding, + data_format=None, ): raise NotImplementedError( "`average_pool` is not supported with openvino backend" @@ -109,25 +107,23 @@ def average_pool( def conv( - inputs, - kernel, - strides=1, - padding="valid", - data_format=None, - dilation_rate=1, + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, ): - raise NotImplementedError( - "`conv` is not supported with openvino backend" - ) + raise NotImplementedError("`conv` is not supported with openvino backend") def depthwise_conv( - inputs, - 
Signed-off-by: Kazantsev, Roman
---
 keras/src/backend/openvino/nn.py | 106 +++++++++++++++----------------
 1 file changed, 50 insertions(+), 56 deletions(-)

diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py
index ac9ac7ac31d..282f6ade9d7 100644
--- a/keras/src/backend/openvino/nn.py
+++ b/keras/src/backend/openvino/nn.py
@@ -10,9 +10,7 @@ def relu(x):
 
 
 def relu6(x):
-    raise NotImplementedError(
-        "`relu6` is not supported with openvino backend"
-    )
+    raise NotImplementedError("`relu6` is not supported with openvino backend")
 
 
 def sigmoid(x):
@@ -62,9 +60,9 @@ def elu(x, alpha=1.0):
 
 
 def selu(
-        x,
-        alpha=1.6732632423543772848170429916717,
-        scale=1.0507009873554804934193349852946,
+    x,
+    alpha=1.6732632423543772848170429916717,
+    scale=1.0507009873554804934193349852946,
 ):
     return opset14.selu(x, alpha, scale)
@@ -85,11 +83,11 @@ def log_softmax(x, axis=None):
 
 
 def max_pool(
-        inputs,
-        pool_size,
-        strides=None,
-        padding="valid",
-        data_format=None,
+    inputs,
+    pool_size,
+    strides=None,
+    padding="valid",
+    data_format=None,
 ):
     raise NotImplementedError(
         "`max_pool` is not supported with openvino backend"
@@ -97,11 +95,11 @@
 def average_pool(
-        inputs,
-        pool_size,
-        strides,
-        padding,
-        data_format=None,
+    inputs,
+    pool_size,
+    strides,
+    padding,
+    data_format=None,
 ):
     raise NotImplementedError(
         "`average_pool` is not supported with openvino backend"
@@ -109,25 +107,23 @@
 
 
 def conv(
-        inputs,
-        kernel,
-        strides=1,
-        padding="valid",
-        data_format=None,
-        dilation_rate=1,
+    inputs,
+    kernel,
+    strides=1,
+    padding="valid",
+    data_format=None,
+    dilation_rate=1,
 ):
-    raise NotImplementedError(
-        "`conv` is not supported with openvino backend"
-    )
+    raise NotImplementedError("`conv` is not supported with openvino backend")
 
 
 def depthwise_conv(
-        inputs,
-        kernel,
-        strides=1,
-        padding="valid",
-        data_format=None,
-        dilation_rate=1,
+    inputs,
+    kernel,
+    strides=1,
+    padding="valid",
+    data_format=None,
+    dilation_rate=1,
 ):
     raise NotImplementedError(
         "`depthwise_conv` is not supported with openvino backend"
@@ -135,13 +131,13 @@
 
 
 def separable_conv(
-        inputs,
-        depthwise_kernel,
-        pointwise_kernel,
-        strides=1,
-        padding="valid",
-        data_format=None,
-        dilation_rate=1,
+    inputs,
+    depthwise_kernel,
+    pointwise_kernel,
+    strides=1,
+    padding="valid",
+    data_format=None,
+    dilation_rate=1,
 ):
     raise NotImplementedError(
         "`separable_conv` is not supported with openvino backend"
@@ -149,13 +145,13 @@
 
 
 def conv_transpose(
-        inputs,
-        kernel,
-        strides=1,
-        padding="valid",
-        output_padding=None,
-        data_format=None,
-        dilation_rate=1,
+    inputs,
+    kernel,
+    strides=1,
+    padding="valid",
+    output_padding=None,
+    data_format=None,
+    dilation_rate=1,
 ):
     raise NotImplementedError(
         "`conv_transpose` is not supported with openvino backend"
@@ -199,7 +195,7 @@ def moments(x, axes, keepdims=False, synchronized=False):
 
 
 def batch_normalization(
-        x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
+    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
 ):
     raise NotImplementedError(
         "`batch_normalization` is not supported with openvino backend"
@@ -213,13 +209,13 @@ def ctc_loss(target, output, target_length, output_length, mask_index=0):
 
 
 def ctc_decode(
-        inputs,
-        sequence_lengths,
-        strategy="greedy",
-        beam_width=100,
-        top_paths=1,
-        merge_repeated=True,
-        mask_index=0,
+    inputs,
+    sequence_lengths,
+    strategy="greedy",
+    beam_width=100,
+    top_paths=1,
+    merge_repeated=True,
+    mask_index=0,
 ):
     raise NotImplementedError(
         "`ctc_decode` is not supported with openvino backend"
@@ -227,6 +223,4 @@
 
 def psnr(x1, x2, max_val):
-    raise NotImplementedError(
-        "`psnr` is not supported with openvino backend"
-    )
+    raise NotImplementedError("`psnr` is not supported with openvino backend")

From 954ed1fe472a23886c4f787c30a8f19eeb3ac6e2 Mon Sep 17 00:00:00 2001
From: "Kazantsev, Roman"
Date: Wed, 25 Sep 2024 21:47:57 +0400
Subject: [PATCH 17/19] Continue format adjustment
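
Follow-up housekeeping pass over the OpenVINO backend sources: unused
imports, apparently left over from the NumPy-backend template, are
removed, and error strings that overflowed the line-length limit are
split via implicit string concatenation. No functional changes.
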
supported with openvino backend" + "`Variable._convert_to_tensor` is not supported " + "with openvino backend" ) # Overload native accessor. diff --git a/keras/src/backend/openvino/image.py b/keras/src/backend/openvino/image.py index 0591382f1f6..aa728845d22 100644 --- a/keras/src/backend/openvino/image.py +++ b/keras/src/backend/openvino/image.py @@ -1,7 +1,3 @@ -from keras.src import backend -from keras.src.utils.module_utils import scipy - - def rgb_to_grayscale(image, data_format="channels_last"): raise NotImplementedError( "`rgb_to_grayscale` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/linalg.py b/keras/src/backend/openvino/linalg.py index 49edc4b2310..3703bd83a0c 100644 --- a/keras/src/backend/openvino/linalg.py +++ b/keras/src/backend/openvino/linalg.py @@ -1,9 +1,3 @@ -import scipy.linalg as sl - -from keras.src.backend import standardize_dtype -from keras.src.backend.common import dtypes - - def cholesky(a): raise NotImplementedError( "`cholesky` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/math.py b/keras/src/backend/openvino/math.py index a149d09bd56..d892af8a808 100644 --- a/keras/src/backend/openvino/math.py +++ b/keras/src/backend/openvino/math.py @@ -1,8 +1,3 @@ -from keras.src.backend import standardize_dtype -from keras.src.backend.common import dtypes -from keras.src.utils.module_utils import scipy - - def segment_sum(data, segment_ids, num_segments=None, sorted=False): raise NotImplementedError( "`segment_sum` is not supported with openvino backend" diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py index 282f6ade9d7..b141219afbc 100644 --- a/keras/src/backend/openvino/nn.py +++ b/keras/src/backend/openvino/nn.py @@ -1,9 +1,6 @@ import numpy as np from openvino.runtime import opset14 -from keras.src import backend -from keras.src.utils.module_utils import scipy - def relu(x): return opset14.relu(x) @@ -178,7 +175,8 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): raise NotImplementedError( - "`sparse_categorical_crossentropy` is not supported with openvino backend" + "`sparse_categorical_crossentropy` is not supported " + "with openvino backend" ) diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index b2fe10bf20e..47699588115 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1,10 +1,6 @@ from openvino.runtime import opset14 -from keras.src import tree -from keras.src.backend import config -from keras.src.backend import standardize_dtype from keras.src.backend.common import dtypes -from keras.src.backend.common.backend_utils import standardize_axis_for_numpy from keras.src.backend.openvino.core import OPENVINO_DTYPES from keras.src.backend.openvino.core import convert_to_tensor from keras.src.backend.openvino.core import ov_to_keras_type @@ -15,7 +11,8 @@ def _align_operand_types(x1, x2, op_name): x2_type = x2.element_type if x1_type.is_dynamic() or x2_type.is_dynamic(): raise ValueError( - f"'{op_name}' operation is not supported for dynamic operand type with openvino backend" + f"'{op_name}' operation is not supported for dynamic operand type " + "with openvino backend" ) x1_type = ov_to_keras_type(x1_type) x2_type = ov_to_keras_type(x2_type) diff --git a/keras/src/backend/openvino/rnn.py b/keras/src/backend/openvino/rnn.py index a1421f87259..70190fc47c8 100644 --- 
Signed-off-by: Kazantsev, Roman
---
 requirements.txt | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/requirements.txt b/requirements.txt
index 14ba558acea..3f83fd5e06b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,5 +13,8 @@ torchvision>=0.16.0
 jax[cpu]
 flax
 
+# OpenVINO
+openvino~=2024.4.0
+
 # Common deps.
 -r requirements-common.txt

From fa6b4616be546dd7cce526bf270451e17f9c6d2a Mon Sep 17 00:00:00 2001
From: "Kazantsev, Roman"
Date: Fri, 1 Nov 2024 08:41:17 +0400
Subject: [PATCH 19/19] Fix inference using OV backend
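
Building `ov.Model` from the traced graph previously failed because the
trainer conflated `Parameter` nodes with their output ports: `ov.Model`
expects the `Parameter` nodes themselves, while graph construction
consumes `Output` objects. The trainer now keeps both, feeding
`param.output(0)` into the traced graph and the `Parameter` list into
`ov.Model`. A minimal standalone sketch of the same pattern (using "CPU"
as an assumed device name, not taken from this patch):

    import openvino as ov
    from openvino.runtime import opset14 as ov_opset

    # Parameter node: kept for ov.Model; its output port feeds the graph.
    param = ov_opset.parameter(shape=[-1, 4], dtype=ov.Type.f32)
    result = ov_opset.relu(param.output(0))

    model = ov.Model(results=[result], parameters=[param])
    compiled = ov.compile_model(model, "CPU")

In addition, `is_tensor` now rejects `ov.runtime.Output` values, and the
new `IS_THREAD_SAFE` flag declares the backend thread-safe to Keras.
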
Signed-off-by: Kazantsev, Roman
---
 keras/src/backend/openvino/__init__.py | 2 ++
 keras/src/backend/openvino/core.py     | 3 +++
 keras/src/backend/openvino/trainer.py  | 6 ++++--
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/keras/src/backend/openvino/__init__.py b/keras/src/backend/openvino/__init__.py
index 8ed2ad1f594..d9148e0a049 100644
--- a/keras/src/backend/openvino/__init__.py
+++ b/keras/src/backend/openvino/__init__.py
@@ -1,3 +1,4 @@
+from keras.src.backend.common.name_scope import name_scope
 from keras.src.backend.openvino import core
 from keras.src.backend.openvino import image
 from keras.src.backend.openvino import linalg
@@ -5,6 +6,7 @@
 from keras.src.backend.openvino import nn
 from keras.src.backend.openvino import numpy
 from keras.src.backend.openvino import random
+from keras.src.backend.openvino.core import IS_THREAD_SAFE
 from keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS
 from keras.src.backend.openvino.core import Variable
 from keras.src.backend.openvino.core import cast
diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py
index 684101751f5..8e6bc154694 100644
--- a/keras/src/backend/openvino/core.py
+++ b/keras/src/backend/openvino/core.py
@@ -12,6 +12,7 @@
 from keras.src.backend.common.stateless_scope import StatelessScope
 
 SUPPORTS_SPARSE_TENSORS = False
+IS_THREAD_SAFE = True
 
 OPENVINO_DTYPES = {
     "float16": ov.Type.f16,
@@ -116,6 +117,8 @@ def convert_to_numpy(x):
 
 
 def is_tensor(x):
+    if isinstance(x, (ov.runtime.Output)):
+        return False
     if isinstance(x, ov.Tensor):
         return True
     return False
diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py
index 364770e0feb..0c1b094392e 100644
--- a/keras/src/backend/openvino/trainer.py
+++ b/keras/src/backend/openvino/trainer.py
@@ -71,6 +71,7 @@ def _get_compiled_model(self):
         # prepare compiled model from scratch
         del self.ov_compiled_model
         ov_inputs = []
+        parameters = []
         for _input in self._inputs:
             ov_type = OPENVINO_DTYPES[_input.dtype]
             ov_shape = _input.shape
@@ -79,13 +80,14 @@
                 if ov_shape[i] is None:
                     ov_shape[i] = -1
             param = ov_opset.parameter(shape=ov_shape, dtype=ov_type)
-            ov_inputs.append(param)
+            parameters.append(param)
+            ov_inputs.append(param.output(0))
         # build OpenVINO graph ov.Model
         ov_outputs = self._run_through_graph(
             ov_inputs, operation_fn=lambda op: op
         )
         ov_outputs = tree.flatten(ov_outputs)
-        ov_model = ov.Model(results=ov_outputs, parameters=ov_inputs)
+        ov_model = ov.Model(results=ov_outputs, parameters=parameters)
         self.ov_compiled_model = ov.compile_model(ov_model, get_device())
         self.ov_device = get_device()
         return self.ov_compiled_model