diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 4ba5c47725d..35670c98328 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -44,6 +44,10 @@ elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 + distribution_lib = None +elif backend() == "openvino": + from keras.src.backend.openvino import * # noqa: F403 + distribution_lib = None else: raise ValueError(f"Unable to import backend : {backend()}") diff --git a/keras/src/backend/exports.py b/keras/src/backend/exports.py index 94f8c29abf7..6e5c508c8ac 100644 --- a/keras/src/backend/exports.py +++ b/keras/src/backend/exports.py @@ -16,6 +16,11 @@ BackendVariable = NumpyVariable backend_name_scope = backend.common.name_scope.name_scope +elif backend.backend() == "openvino": + from keras.src.backend.openvino.core import Variable as OpenVINOVariable + + BackendVariable = OpenVINOVariable + backend_name_scope = backend.common.name_scope.name_scope else: raise RuntimeError(f"Invalid backend: {backend.backend()}") diff --git a/keras/src/backend/openvino/__init__.py b/keras/src/backend/openvino/__init__.py new file mode 100644 index 00000000000..8ed2ad1f594 --- /dev/null +++ b/keras/src/backend/openvino/__init__.py @@ -0,0 +1,22 @@ +from keras.src.backend.openvino import core +from keras.src.backend.openvino import image +from keras.src.backend.openvino import linalg +from keras.src.backend.openvino import math +from keras.src.backend.openvino import nn +from keras.src.backend.openvino import numpy +from keras.src.backend.openvino import random +from keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS +from keras.src.backend.openvino.core import Variable +from keras.src.backend.openvino.core import cast +from keras.src.backend.openvino.core import compute_output_spec +from keras.src.backend.openvino.core import cond +from keras.src.backend.openvino.core import convert_to_numpy +from keras.src.backend.openvino.core import convert_to_tensor +from keras.src.backend.openvino.core import is_tensor +from keras.src.backend.openvino.core import random_seed_dtype +from keras.src.backend.openvino.core import shape +from keras.src.backend.openvino.core import vectorized_map +from keras.src.backend.openvino.rnn import cudnn_ok +from keras.src.backend.openvino.rnn import gru +from keras.src.backend.openvino.rnn import lstm +from keras.src.backend.openvino.rnn import rnn diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py new file mode 100644 index 00000000000..684101751f5 --- /dev/null +++ b/keras/src/backend/openvino/core.py @@ -0,0 +1,262 @@ +import contextlib + +import numpy as np +import openvino as ov + +from keras.src import tree +from keras.src.backend.common import KerasVariable +from keras.src.backend.common import global_state +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.dtypes import result_type +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.backend.common.stateless_scope import StatelessScope + +SUPPORTS_SPARSE_TENSORS = False + +OPENVINO_DTYPES = { + "float16": ov.Type.f16, + "float32": ov.Type.f32, + "float64": ov.Type.f64, + "uint8": ov.Type.u8, + "uint16": ov.Type.u16, + "uint32": ov.Type.u32, + "int8": ov.Type.i8, + "int16": ov.Type.i16, + "int32": ov.Type.i32, + "int64": ov.Type.i64, + "bfloat16": ov.Type.bf16, + "bool": ov.Type.boolean, + "float8_e4m3fn": ov.Type.f8e4m3, + "float8_e5m2": ov.Type.f8e5m2, +} + + +def ov_to_keras_type(ov_type): + for _keras_type, _ov_type in OPENVINO_DTYPES.items(): + if ov_type == _ov_type: + return _keras_type + raise ValueError( + f"Requested OpenVINO type has no keras analogue '{ov_type.to_string()}'" + ) + + +@contextlib.contextmanager +def device_scope(device_name): + current_device = _parse_device_input(device_name) + global_state.set_global_attribute("openvino_device", current_device) + + +def get_device(): + device = global_state.get_global_attribute("openvino_device", None) + if device is None: + return "CPU" + return device + + +def _parse_device_input(device_name): + if isinstance(device_name, str): + # We support string value like "cpu:0", "gpu:1", and need to convert + # "gpu" to "cuda" + device_name = device_name.upper() + device_type, _ = device_name.split(":") + return device_type + else: + raise ValueError( + "Invalid value for argument `device_name`. " + "Expected a string like 'gpu:0' or 'cpu'. " + f"Received: device_name='{device_name}'" + ) + return device_name + + +class Variable(KerasVariable): + def _initialize(self, value): + raise NotImplementedError( + "`Variable._initialize` is not supported with openvino backend" + ) + + def _direct_assign(self, value): + raise NotImplementedError( + "`Variable._direct_assign` is not supported with openvino backend" + ) + + def _convert_to_tensor(self, value, dtype=None): + raise NotImplementedError( + "`Variable._convert_to_tensor` is not supported " + "with openvino backend" + ) + + # Overload native accessor. + def __array__(self): + raise NotImplementedError( + "`Variable.__array__` is not supported with openvino backend" + ) + + +def convert_to_tensor(x, dtype=None, sparse=None): + if sparse: + raise ValueError("`sparse=True` is not supported with numpy backend") + if dtype is not None: + dtype = standardize_dtype(dtype) + if isinstance(x, Variable): + if dtype and dtype != x.dtype: + return x.value.astype(dtype) + return x.value + if not is_tensor(x) and standardize_dtype(dtype) == "bfloat16": + return ov.Tensor(np.asarray(x).astype(dtype)) + if dtype is None: + dtype = result_type( + *[getattr(item, "dtype", type(item)) for item in tree.flatten(x)] + ) + return ov.Tensor(np.array(x, dtype=dtype)) + + +def convert_to_numpy(x): + return np.array(x) + + +def is_tensor(x): + if isinstance(x, ov.Tensor): + return True + return False + + +def shape(x): + return x.shape + + +def cast(x, dtype): + raise NotImplementedError("`cast` is not supported with openvino backend") + + +def cond(pred, true_fn, false_fn): + raise NotImplementedError("`cond` is not supported with openvino backend") + + +def vectorized_map(function, elements): + raise NotImplementedError( + "`vectorized_map` is not supported with openvino backend" + ) + + +# Shape / dtype inference util +def compute_output_spec(fn, *args, **kwargs): + with StatelessScope(): + + def has_none_shape(x): + if isinstance(x, KerasTensor): + return None in x.shape + return False + + none_in_shape = any(map(has_none_shape, tree.flatten((args, kwargs)))) + + def convert_keras_tensor_to_numpy(x, fill_value=None): + if isinstance(x, KerasTensor): + shape = list(x.shape) + if fill_value: + for i, e in enumerate(shape): + if e is None: + shape[i] = fill_value + return np.empty( + shape=shape, + dtype=x.dtype, + ) + return x + + args_1, kwargs_1 = tree.map_structure( + lambda x: convert_keras_tensor_to_numpy(x, fill_value=83), + (args, kwargs), + ) + outputs_1 = fn(*args_1, **kwargs_1) + + outputs = outputs_1 + + if none_in_shape: + args_2, kwargs_2 = tree.map_structure( + lambda x: convert_keras_tensor_to_numpy(x, fill_value=89), + (args, kwargs), + ) + outputs_2 = fn(*args_2, **kwargs_2) + + flat_out_1 = tree.flatten(outputs_1) + flat_out_2 = tree.flatten(outputs_2) + + flat_out = [] + for x1, x2 in zip(flat_out_1, flat_out_2): + shape = list(x1.shape) + for i, e in enumerate(x2.shape): + if e != shape[i]: + shape[i] = None + flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype))) + outputs = tree.pack_sequence_as(outputs_1, flat_out) + + def convert_numpy_to_keras_tensor(x): + if is_tensor(x): + return KerasTensor(x.shape, standardize_dtype(x.dtype)) + return x + + output_spec = tree.map_structure(convert_numpy_to_keras_tensor, outputs) + return output_spec + + +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + raise NotImplementedError("`scan` is not supported with openvino backend") + + +def scatter(indices, values, shape): + raise NotImplementedError( + "`scatter` is not supported with openvino backend" + ) + + +def scatter_update(inputs, indices, updates): + raise NotImplementedError( + "`scatter_update` is not supported with openvino backend" + ) + + +def slice(inputs, start_indices, lengths): + raise NotImplementedError("`slice` is not supported with openvino backend") + + +def slice_update(inputs, start_indices, updates): + raise NotImplementedError( + "`slice_update` is not supported with openvino backend" + ) + + +def while_loop( + cond, + body, + loop_vars, + maximum_iterations=None, +): + raise NotImplementedError( + "`while_loop` is not supported with openvino backend" + ) + + +def fori_loop(lower, upper, body_fun, init_val): + raise NotImplementedError( + "`fori_loop` is not supported with openvino backend" + ) + + +def stop_gradient(x): + return x + + +def unstack(x, num=None, axis=0): + raise NotImplementedError( + "`unstack` is not supported with openvino backend" + ) + + +def random_seed_dtype(): + return "uint32" + + +def custom_gradient(fun): + raise NotImplementedError( + "`custom_gradient` is not supported with openvino backend" + ) diff --git a/keras/src/backend/openvino/image.py b/keras/src/backend/openvino/image.py new file mode 100644 index 00000000000..aa728845d22 --- /dev/null +++ b/keras/src/backend/openvino/image.py @@ -0,0 +1,39 @@ +def rgb_to_grayscale(image, data_format="channels_last"): + raise NotImplementedError( + "`rgb_to_grayscale` is not supported with openvino backend" + ) + + +def resize( + image, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format="channels_last", +): + raise NotImplementedError("`resize` is not supported with openvino backend") + + +def affine_transform( + image, + transform, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + data_format="channels_last", +): + raise NotImplementedError( + "`affine_transform` is not supported with openvino backend" + ) + + +def map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + raise NotImplementedError( + "`map_coordinates` is not supported with openvino backend" + ) diff --git a/keras/src/backend/openvino/layer.py b/keras/src/backend/openvino/layer.py new file mode 100644 index 00000000000..334c32958a7 --- /dev/null +++ b/keras/src/backend/openvino/layer.py @@ -0,0 +1,2 @@ +class OpenvinoLayer: + pass diff --git a/keras/src/backend/openvino/linalg.py b/keras/src/backend/openvino/linalg.py new file mode 100644 index 00000000000..3703bd83a0c --- /dev/null +++ b/keras/src/backend/openvino/linalg.py @@ -0,0 +1,52 @@ +def cholesky(a): + raise NotImplementedError( + "`cholesky` is not supported with openvino backend" + ) + + +def det(a): + raise NotImplementedError("`det` is not supported with openvino backend") + + +def eig(a): + raise NotImplementedError("`eig` is not supported with openvino backend") + + +def eigh(a): + raise NotImplementedError("`eigh` is not supported with openvino backend") + + +def inv(a): + raise NotImplementedError("`inv` is not supported with openvino backend") + + +def lu_factor(a): + raise NotImplementedError( + "`lu_factor` is not supported with openvino backend" + ) + + +def norm(x, ord=None, axis=None, keepdims=False): + raise NotImplementedError("`norm` is not supported with openvino backend") + + +def qr(x, mode="reduced"): + raise NotImplementedError("`qr` is not supported with openvino backend") + + +def solve(a, b): + raise NotImplementedError("`solve` is not supported with openvino backend") + + +def solve_triangular(a, b, lower=False): + raise NotImplementedError( + "`solve_triangular` is not supported with openvino backend" + ) + + +def svd(x, full_matrices=True, compute_uv=True): + raise NotImplementedError("`svd` is not supported with openvino backend") + + +def lstsq(a, b, rcond=None): + raise NotImplementedError("`lstsq` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/math.py b/keras/src/backend/openvino/math.py new file mode 100644 index 00000000000..d892af8a808 --- /dev/null +++ b/keras/src/backend/openvino/math.py @@ -0,0 +1,90 @@ +def segment_sum(data, segment_ids, num_segments=None, sorted=False): + raise NotImplementedError( + "`segment_sum` is not supported with openvino backend" + ) + + +def segment_max(data, segment_ids, num_segments=None, sorted=False): + raise NotImplementedError( + "`segment_max` is not supported with openvino backend" + ) + + +def top_k(x, k, sorted=False): + raise NotImplementedError("`top_k` is not supported with openvino backend") + + +def in_top_k(targets, predictions, k): + raise NotImplementedError( + "`in_top_k` is not supported with openvino backend" + ) + + +def logsumexp(x, axis=None, keepdims=False): + raise NotImplementedError( + "`logsumexp` is not supported with openvino backend" + ) + + +def qr(x, mode="reduced"): + raise NotImplementedError("`qr` is not supported with openvino backend") + + +def extract_sequences(x, sequence_length, sequence_stride): + raise NotImplementedError( + "`extract_sequences` is not supported with openvino backend" + ) + + +def fft(x): + raise NotImplementedError("`fft` is not supported with openvino backend") + + +def fft2(x): + raise NotImplementedError("`fft2` is not supported with openvino backend") + + +def rfft(x, fft_length=None): + raise NotImplementedError("`rfft` is not supported with openvino backend") + + +def irfft(x, fft_length=None): + raise NotImplementedError("`irfft` is not supported with openvino backend") + + +def stft( + x, sequence_length, sequence_stride, fft_length, window="hann", center=True +): + raise NotImplementedError("`stft` is not supported with openvino backend") + + +def istft( + x, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, +): + raise NotImplementedError("`istft` is not supported with openvino backend") + + +def rsqrt(x): + raise NotImplementedError("`rsqrt` is not supported with openvino backend") + + +def erf(x): + raise NotImplementedError("`erf` is not supported with openvino backend") + + +def erfinv(x): + raise NotImplementedError("`erfinv` is not supported with openvino backend") + + +def solve(a, b): + raise NotImplementedError("`solve` is not supported with openvino backend") + + +def norm(x, ord=None, axis=None, keepdims=False): + raise NotImplementedError("`norm` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py new file mode 100644 index 00000000000..b141219afbc --- /dev/null +++ b/keras/src/backend/openvino/nn.py @@ -0,0 +1,224 @@ +import numpy as np +from openvino.runtime import opset14 + + +def relu(x): + return opset14.relu(x) + + +def relu6(x): + raise NotImplementedError("`relu6` is not supported with openvino backend") + + +def sigmoid(x): + return opset14.sigmoid(x) + + +def tanh(x): + return opset14.tanh(x) + + +def softplus(x): + return opset14.softplus(x) + + +def softsign(x): + return opset14.softsign(x) + + +def silu(x): + return opset14.multiply(x, opset14.sigmoid(x)) + + +def log_sigmoid(x): + raise NotImplementedError( + "`log_sigmoid` is not supported with openvino backend" + ) + + +def leaky_relu(x, negative_slope=0.2): + raise NotImplementedError( + "`leaky_relu` is not supported with openvino backend" + ) + + +def hard_sigmoid(x): + alpha = 1 / np.array(6.0, dtype=np.float32) + beta = np.array(0.5, dtype=np.float32) + return opset14.hard_sigmoid(x, alpha, beta) + + +def hard_silu(x): + return opset14.multiply(x, hard_sigmoid(x)) + + +def elu(x, alpha=1.0): + return opset14.elu(x, alpha) + + +def selu( + x, + alpha=1.6732632423543772848170429916717, + scale=1.0507009873554804934193349852946, +): + return opset14.selu(x, alpha, scale) + + +def gelu(x, approximate=True): + approximate_mode = "erf" + if approximate: + approximate_mode = "tanh" + return opset14.gelu(x, approximate_mode) + + +def softmax(x, axis=None): + return opset14.softmax(x, axis) + + +def log_softmax(x, axis=None): + return opset14.log_softmax(x, axis) + + +def max_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + raise NotImplementedError( + "`max_pool` is not supported with openvino backend" + ) + + +def average_pool( + inputs, + pool_size, + strides, + padding, + data_format=None, +): + raise NotImplementedError( + "`average_pool` is not supported with openvino backend" + ) + + +def conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + raise NotImplementedError("`conv` is not supported with openvino backend") + + +def depthwise_conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + raise NotImplementedError( + "`depthwise_conv` is not supported with openvino backend" + ) + + +def separable_conv( + inputs, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + raise NotImplementedError( + "`separable_conv` is not supported with openvino backend" + ) + + +def conv_transpose( + inputs, + kernel, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, +): + raise NotImplementedError( + "`conv_transpose` is not supported with openvino backend" + ) + + +def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False): + raise NotImplementedError( + "`one_hot` is not supported with openvino backend" + ) + + +def multi_hot(x, num_classes, axis=-1, dtype="float32", sparse=False): + raise NotImplementedError( + "`multi_hot` is not supported with openvino backend" + ) + + +def categorical_crossentropy(target, output, from_logits=False, axis=-1): + raise NotImplementedError( + "`categorical_crossentropy` is not supported with openvino backend" + ) + + +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): + raise NotImplementedError( + "`sparse_categorical_crossentropy` is not supported " + "with openvino backend" + ) + + +def binary_crossentropy(target, output, from_logits=False): + raise NotImplementedError( + "`binary_crossentropy` is not supported with openvino backend" + ) + + +def moments(x, axes, keepdims=False, synchronized=False): + raise NotImplementedError( + "`moments` is not supported with openvino backend" + ) + + +def batch_normalization( + x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 +): + raise NotImplementedError( + "`batch_normalization` is not supported with openvino backend" + ) + + +def ctc_loss(target, output, target_length, output_length, mask_index=0): + raise NotImplementedError( + "`ctc_loss` is not supported with openvino backend" + ) + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + raise NotImplementedError( + "`ctc_decode` is not supported with openvino backend" + ) + + +def psnr(x1, x2, max_val): + raise NotImplementedError("`psnr` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py new file mode 100644 index 00000000000..47699588115 --- /dev/null +++ b/keras/src/backend/openvino/numpy.py @@ -0,0 +1,706 @@ +from openvino.runtime import opset14 + +from keras.src.backend.common import dtypes +from keras.src.backend.openvino.core import OPENVINO_DTYPES +from keras.src.backend.openvino.core import convert_to_tensor +from keras.src.backend.openvino.core import ov_to_keras_type + + +def _align_operand_types(x1, x2, op_name): + x1_type = x1.element_type + x2_type = x2.element_type + if x1_type.is_dynamic() or x2_type.is_dynamic(): + raise ValueError( + f"'{op_name}' operation is not supported for dynamic operand type " + "with openvino backend" + ) + x1_type = ov_to_keras_type(x1_type) + x2_type = ov_to_keras_type(x2_type) + result_type = dtypes.result_type(x1_type, x2_type) + result_type = OPENVINO_DTYPES[result_type] + if x1_type != result_type: + x1 = opset14.convert(x1, result_type) + if x2_type != result_type: + x2 = opset14.convert(x2, result_type) + return x1, x2 + + +def add(x1, x2): + x1, x2 = _align_operand_types(x1, x2, "add()") + return opset14.add(x1, x2) + + +def einsum(subscripts, *operands, **kwargs): + raise NotImplementedError("`einsum` is not supported with openvino backend") + + +def subtract(x1, x2): + x1, x2 = _align_operand_types(x1, x2, "subtract()") + return opset14.subtract(x1, x2) + + +def matmul(x1, x2): + raise NotImplementedError("`matmul` is not supported with openvino backend") + + +def multiply(x1, x2): + x1, x2 = _align_operand_types(x1, x2, "multiply()") + return opset14.multiply(x1, x2) + + +def mean(x, axis=None, keepdims=False): + raise NotImplementedError("`mean` is not supported with openvino backend") + + +def max(x, axis=None, keepdims=False, initial=None): + raise NotImplementedError("`max` is not supported with openvino backend") + + +def ones(shape, dtype=None): + raise NotImplementedError("`ones` is not supported with openvino backend") + + +def zeros(shape, dtype=None): + raise NotImplementedError("`zeros` is not supported with openvino backend") + + +def absolute(x): + return opset14.absolute(x) + + +def abs(x): + return opset14.absolute(x) + + +def all(x, axis=None, keepdims=False): + raise NotImplementedError("`all` is not supported with openvino backend") + + +def any(x, axis=None, keepdims=False): + raise NotImplementedError("`any` is not supported with openvino backend") + + +def amax(x, axis=None, keepdims=False): + raise NotImplementedError("`amax` is not supported with openvino backend") + + +def amin(x, axis=None, keepdims=False): + raise NotImplementedError("`amin` is not supported with openvino backend") + + +def append(x1, x2, axis=None): + raise NotImplementedError("`append` is not supported with openvino backend") + + +def arange(start, stop=None, step=None, dtype=None): + raise NotImplementedError("`arange` is not supported with openvino backend") + + +def arccos(x): + raise NotImplementedError("`arccos` is not supported with openvino backend") + + +def arccosh(x): + raise NotImplementedError( + "`arccosh` is not supported with openvino backend" + ) + + +def arcsin(x): + raise NotImplementedError("`arcsin` is not supported with openvino backend") + + +def arcsinh(x): + raise NotImplementedError( + "`arcsinh` is not supported with openvino backend" + ) + + +def arctan(x): + raise NotImplementedError("`arctan` is not supported with openvino backend") + + +def arctan2(x1, x2): + raise NotImplementedError( + "`arctan2` is not supported with openvino backend" + ) + + +def arctanh(x): + raise NotImplementedError( + "`arctanh` is not supported with openvino backend" + ) + + +def argmax(x, axis=None, keepdims=False): + raise NotImplementedError("`argmax` is not supported with openvino backend") + + +def argmin(x, axis=None, keepdims=False): + raise NotImplementedError("`argmin` is not supported with openvino backend") + + +def argsort(x, axis=-1): + raise NotImplementedError( + "`argsort` is not supported with openvino backend" + ) + + +def array(x, dtype=None): + return convert_to_tensor(x, dtype=dtype) + + +def average(x, axis=None, weights=None): + raise NotImplementedError( + "`average` is not supported with openvino backend" + ) + + +def bincount(x, weights=None, minlength=0, sparse=False): + raise NotImplementedError( + "`bincount` is not supported with openvino backend" + ) + + +def broadcast_to(x, shape): + raise NotImplementedError( + "`broadcast_to` is not supported with openvino backend" + ) + + +def ceil(x): + raise NotImplementedError("`ceil` is not supported with openvino backend") + + +def clip(x, x_min, x_max): + raise NotImplementedError("`clip` is not supported with openvino backend") + + +def concatenate(xs, axis=0): + raise NotImplementedError( + "`concatenate` is not supported with openvino backend" + ) + + +def conjugate(x): + raise NotImplementedError( + "`conjugate` is not supported with openvino backend" + ) + + +def conj(x): + raise NotImplementedError("`conj` is not supported with openvino backend") + + +def copy(x): + raise NotImplementedError("`copy` is not supported with openvino backend") + + +def cos(x): + raise NotImplementedError("`cos` is not supported with openvino backend") + + +def cosh(x): + raise NotImplementedError("`cosh` is not supported with openvino backend") + + +def count_nonzero(x, axis=None): + raise NotImplementedError( + "`count_nonzero` is not supported with openvino backend" + ) + + +def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): + raise NotImplementedError("`cross` is not supported with openvino backend") + + +def cumprod(x, axis=None, dtype=None): + raise NotImplementedError( + "`cumprod` is not supported with openvino backend" + ) + + +def cumsum(x, axis=None, dtype=None): + raise NotImplementedError("`cumsum` is not supported with openvino backend") + + +def diag(x, k=0): + raise NotImplementedError("`diag` is not supported with openvino backend") + + +def diagonal(x, offset=0, axis1=0, axis2=1): + raise NotImplementedError( + "`diagonal` is not supported with openvino backend" + ) + + +def diff(a, n=1, axis=-1): + raise NotImplementedError("`diff` is not supported with openvino backend") + + +def digitize(x, bins): + raise NotImplementedError( + "`digitize` is not supported with openvino backend" + ) + + +def dot(x, y): + raise NotImplementedError("`dot` is not supported with openvino backend") + + +def empty(shape, dtype=None): + raise NotImplementedError("`empty` is not supported with openvino backend") + + +def equal(x1, x2): + raise NotImplementedError("`equal` is not supported with openvino backend") + + +def exp(x): + raise NotImplementedError("`exp` is not supported with openvino backend") + + +def expand_dims(x, axis): + raise NotImplementedError( + "`expand_dims` is not supported with openvino backend" + ) + + +def expm1(x): + raise NotImplementedError("`expm1` is not supported with openvino backend") + + +def flip(x, axis=None): + raise NotImplementedError("`flip` is not supported with openvino backend") + + +def floor(x): + raise NotImplementedError("`floor` is not supported with openvino backend") + + +def full(shape, fill_value, dtype=None): + raise NotImplementedError("`full` is not supported with openvino backend") + + +def full_like(x, fill_value, dtype=None): + raise NotImplementedError( + "`full_like` is not supported with openvino backend" + ) + + +def greater(x1, x2): + raise NotImplementedError( + "`greater` is not supported with openvino backend" + ) + + +def greater_equal(x1, x2): + raise NotImplementedError( + "`greater_equal` is not supported with openvino backend" + ) + + +def hstack(xs): + raise NotImplementedError("`hstack` is not supported with openvino backend") + + +def identity(n, dtype=None): + raise NotImplementedError( + "`identity` is not supported with openvino backend" + ) + + +def imag(x): + raise NotImplementedError("`imag` is not supported with openvino backend") + + +def isclose(x1, x2): + raise NotImplementedError( + "`isclose` is not supported with openvino backend" + ) + + +def isfinite(x): + raise NotImplementedError( + "`isfinite` is not supported with openvino backend" + ) + + +def isinf(x): + raise NotImplementedError("`isinf` is not supported with openvino backend") + + +def isnan(x): + raise NotImplementedError("`isnan` is not supported with openvino backend") + + +def less(x1, x2): + raise NotImplementedError("`less` is not supported with openvino backend") + + +def less_equal(x1, x2): + raise NotImplementedError( + "`less_equal` is not supported with openvino backend" + ) + + +def linspace( + start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 +): + raise NotImplementedError( + "`linspace` is not supported with openvino backend" + ) + + +def log(x): + raise NotImplementedError("`log` is not supported with openvino backend") + + +def log10(x): + raise NotImplementedError("`log10` is not supported with openvino backend") + + +def log1p(x): + raise NotImplementedError("`log1p` is not supported with openvino backend") + + +def log2(x): + raise NotImplementedError("`log2` is not supported with openvino backend") + + +def logaddexp(x1, x2): + raise NotImplementedError( + "`logaddexp` is not supported with openvino backend" + ) + + +def logical_and(x1, x2): + raise NotImplementedError( + "`logical_and` is not supported with openvino backend" + ) + + +def logical_not(x): + raise NotImplementedError( + "`logical_not` is not supported with openvino backend" + ) + + +def logical_or(x1, x2): + raise NotImplementedError( + "`logical_or` is not supported with openvino backend" + ) + + +def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): + raise NotImplementedError( + "`logspace` is not supported with openvino backend" + ) + + +def maximum(x1, x2): + raise NotImplementedError( + "`maximum` is not supported with openvino backend" + ) + + +def median(x, axis=None, keepdims=False): + raise NotImplementedError("`median` is not supported with openvino backend") + + +def meshgrid(*x, indexing="xy"): + raise NotImplementedError( + "`meshgrid` is not supported with openvino backend" + ) + + +def min(x, axis=None, keepdims=False, initial=None): + raise NotImplementedError("`min` is not supported with openvino backend") + + +def minimum(x1, x2): + raise NotImplementedError( + "`minimum` is not supported with openvino backend" + ) + + +def mod(x1, x2): + raise NotImplementedError("`mod` is not supported with openvino backend") + + +def moveaxis(x, source, destination): + raise NotImplementedError( + "`moveaxis` is not supported with openvino backend" + ) + + +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): + raise NotImplementedError( + "`nan_to_num` is not supported with openvino backend" + ) + + +def ndim(x): + raise NotImplementedError("`ndim` is not supported with openvino backend") + + +def nonzero(x): + raise NotImplementedError( + "`nonzero` is not supported with openvino backend" + ) + + +def not_equal(x1, x2): + raise NotImplementedError( + "`not_equal` is not supported with openvino backend" + ) + + +def zeros_like(x, dtype=None): + raise NotImplementedError( + "`zeros_like` is not supported with openvino backend" + ) + + +def ones_like(x, dtype=None): + raise NotImplementedError( + "`ones_like` is not supported with openvino backend" + ) + + +def outer(x1, x2): + raise NotImplementedError("`outer` is not supported with openvino backend") + + +def pad(x, pad_width, mode="constant", constant_values=None): + raise NotImplementedError("`pad` is not supported with openvino backend") + + +def prod(x, axis=None, keepdims=False, dtype=None): + raise NotImplementedError("`prod` is not supported with openvino backend") + + +def quantile(x, q, axis=None, method="linear", keepdims=False): + raise NotImplementedError( + "`quantile` is not supported with openvino backend" + ) + + +def ravel(x): + raise NotImplementedError("`ravel` is not supported with openvino backend") + + +def real(x): + raise NotImplementedError("`real` is not supported with openvino backend") + + +def reciprocal(x): + raise NotImplementedError( + "`reciprocal` is not supported with openvino backend" + ) + + +def repeat(x, repeats, axis=None): + raise NotImplementedError("`repeat` is not supported with openvino backend") + + +def reshape(x, newshape): + raise NotImplementedError( + "`reshape` is not supported with openvino backend" + ) + + +def roll(x, shift, axis=None): + raise NotImplementedError("`roll` is not supported with openvino backend") + + +def sign(x): + raise NotImplementedError("`sign` is not supported with openvino backend") + + +def sin(x): + raise NotImplementedError("`sin` is not supported with openvino backend") + + +def sinh(x): + raise NotImplementedError("`sinh` is not supported with openvino backend") + + +def size(x): + raise NotImplementedError("`size` is not supported with openvino backend") + + +def sort(x, axis=-1): + raise NotImplementedError("`sort` is not supported with openvino backend") + + +def split(x, indices_or_sections, axis=0): + raise NotImplementedError("`split` is not supported with openvino backend") + + +def stack(x, axis=0): + raise NotImplementedError("`stack` is not supported with openvino backend") + + +def std(x, axis=None, keepdims=False): + raise NotImplementedError("`std` is not supported with openvino backend") + + +def swapaxes(x, axis1, axis2): + raise NotImplementedError( + "`swapaxes` is not supported with openvino backend" + ) + + +def take(x, indices, axis=None): + raise NotImplementedError("`take` is not supported with openvino backend") + + +def take_along_axis(x, indices, axis=None): + raise NotImplementedError( + "`take_along_axis` is not supported with openvino backend" + ) + + +def tan(x): + raise NotImplementedError("`tan` is not supported with openvino backend") + + +def tanh(x): + raise NotImplementedError("`tanh` is not supported with openvino backend") + + +def tensordot(x1, x2, axes=2): + raise NotImplementedError( + "`tensordot` is not supported with openvino backend" + ) + + +def round(x, decimals=0): + raise NotImplementedError("`round` is not supported with openvino backend") + + +def tile(x, repeats): + raise NotImplementedError("`tile` is not supported with openvino backend") + + +def trace(x, offset=0, axis1=0, axis2=1): + raise NotImplementedError("`trace` is not supported with openvino backend") + + +def tri(N, M=None, k=0, dtype=None): + raise NotImplementedError("`tri` is not supported with openvino backend") + + +def tril(x, k=0): + raise NotImplementedError("`tril` is not supported with openvino backend") + + +def triu(x, k=0): + raise NotImplementedError("`triu` is not supported with openvino backend") + + +def vdot(x1, x2): + raise NotImplementedError("`vdot` is not supported with openvino backend") + + +def vstack(xs): + raise NotImplementedError("`vstack` is not supported with openvino backend") + + +def vectorize(pyfunc, *, excluded=None, signature=None): + raise NotImplementedError( + "`vectorize` is not supported with openvino backend" + ) + + +def where(condition, x1, x2): + raise NotImplementedError("`where` is not supported with openvino backend") + + +def divide(x1, x2): + x1, x2 = _align_operand_types(x1, x2) + return opset14.divide(x1, x2) + + +def divide_no_nan(x1, x2): + raise NotImplementedError( + "`divide_no_nan` is not supported with openvino backend" + ) + + +def true_divide(x1, x2): + return divide(x1, x2) + + +def power(x1, x2): + x1, x2 = _align_operand_types(x1, x2) + return opset14.power(x1, x2) + + +def negative(x): + return opset14.negative(x) + + +def square(x): + raise NotImplementedError("`square` is not supported with openvino backend") + + +def sqrt(x): + raise NotImplementedError("`sqrt` is not supported with openvino backend") + + +def squeeze(x, axis=None): + raise NotImplementedError( + "`squeeze` is not supported with openvino backend" + ) + + +def transpose(x, axes=None): + raise NotImplementedError( + "`transpose` is not supported with openvino backend" + ) + + +def var(x, axis=None, keepdims=False): + raise NotImplementedError("`var` is not supported with openvino backend") + + +def sum(x, axis=None, keepdims=False): + raise NotImplementedError("`sum` is not supported with openvino backend") + + +def eye(N, M=None, k=0, dtype=None): + raise NotImplementedError("`eye` is not supported with openvino backend") + + +def floor_divide(x1, x2): + raise NotImplementedError( + "`floor_divide` is not supported with openvino backend" + ) + + +def logical_xor(x1, x2): + return opset14.logical_xor(x1, x2) + + +def correlate(x1, x2, mode="valid"): + raise NotImplementedError( + "`correlate` is not supported with openvino backend" + ) + + +def select(condlist, choicelist, default=0): + raise NotImplementedError("`select` is not supported with openvino backend") + + +def slogdet(x): + raise NotImplementedError( + "`slogdet` is not supported with openvino backend" + ) + + +def argpartition(x, kth, axis=-1): + raise NotImplementedError( + "`argpartition` is not supported with openvino backend" + ) diff --git a/keras/src/backend/openvino/random.py b/keras/src/backend/openvino/random.py new file mode 100644 index 00000000000..4d6565d6f58 --- /dev/null +++ b/keras/src/backend/openvino/random.py @@ -0,0 +1,58 @@ +from keras.src.backend.config import floatx +from keras.src.random.seed_generator import SeedGenerator +from keras.src.random.seed_generator import draw_seed +from keras.src.random.seed_generator import make_default_seed + + +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + raise NotImplementedError("`normal` is not supported with openvino backend") + + +def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): + raise NotImplementedError( + "`uniform` is not supported with openvino backend" + ) + + +def categorical(logits, num_samples, dtype="int64", seed=None): + raise NotImplementedError( + "`categorical` is not supported with openvino backend" + ) + + +def randint(shape, minval, maxval, dtype="int32", seed=None): + raise NotImplementedError( + "`randint` is not supported with openvino backend" + ) + + +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + raise NotImplementedError( + "`truncated_normal` is not supported with openvino backend" + ) + + +def dropout(inputs, rate, noise_shape=None, seed=None): + raise NotImplementedError( + "`dropout` is not supported with openvino backend" + ) + + +def shuffle(x, axis=0, seed=None): + raise NotImplementedError( + "`shuffle` is not supported with openvino backend" + ) + + +def gamma(shape, alpha, dtype=None, seed=None): + raise NotImplementedError("`gamma` is not supported with openvino backend") + + +def binomial(shape, counts, probabilities, dtype=None, seed=None): + raise NotImplementedError( + "`binomial` is not supported with openvino backend" + ) + + +def beta(shape, alpha, beta, dtype=None, seed=None): + raise NotImplementedError("`beta` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/rnn.py b/keras/src/backend/openvino/rnn.py new file mode 100644 index 00000000000..70190fc47c8 --- /dev/null +++ b/keras/src/backend/openvino/rnn.py @@ -0,0 +1,38 @@ +def rnn( + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, +): + raise NotImplementedError("`rnn` is not supported with openvino backend") + + +def lstm(*args, **kwargs): + raise NotImplementedError("`lstm` is not supported with openvino backend") + + +def gru(*args, **kwargs): + raise NotImplementedError("`gru` is not supported with openvino backend") + + +def unstack(x, axis=0): + raise NotImplementedError( + "`unstack` is not supported with openvino backend" + ) + + +def numpy_scan(f, init, xs, reverse=False, mask=None): + raise NotImplementedError( + "`numpy_scan` is not supported with openvino backend" + ) + + +def cudnn_ok(*args, **kwargs): + return False diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py new file mode 100644 index 00000000000..364770e0feb --- /dev/null +++ b/keras/src/backend/openvino/trainer.py @@ -0,0 +1,243 @@ +import numpy as np +import openvino as ov +import openvino.runtime.opset14 as ov_opset + +from keras.src import backend +from keras.src import callbacks as callbacks_module +from keras.src import tree +from keras.src.backend.openvino.core import OPENVINO_DTYPES +from keras.src.backend.openvino.core import get_device +from keras.src.trainers import trainer as base_trainer +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.epoch_iterator import EpochIterator +from keras.src.utils import traceback_utils + + +class OpenVINOTrainer(base_trainer.Trainer): + def __init__(self): + super().__init__() + self.test_function = None + self.predict_function = None + self.ov_compiled_model = None + self.ov_device = None + + def _unpack_singleton(self, x): + if isinstance(x, (list, tuple)) and len(x) == 1: + return x[0] + return x + + def test_step(self, data): + raise NotImplementedError( + "`test_step` is not supported with openvino backend" + ) + + def predict_step(self, data): + x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) + ov_compiled_model = self._get_compiled_model() + flatten_x = tree.flatten(x) + y_pred = ov_compiled_model(flatten_x) + y_pred = self._unpack_singleton( + tree.pack_sequence_as(self._outputs_struct, y_pred.to_tuple()) + ) + return y_pred + + def make_test_function(self, force=False): + if self.test_function is not None and not force: + return self.test_function + + def one_test_step(data): + data = data[0] + return self.test_step(data) + + def multi_test_steps(data): + for single_step_data in data: + logs = one_test_step([single_step_data]) + return logs + + if self.steps_per_execution > 1: + test_step = multi_test_steps + else: + test_step = one_test_step + + self.test_function = test_step + + def _get_compiled_model(self): + if ( + self.ov_compiled_model is not None + and get_device() == self.ov_device + ): + return self.ov_compiled_model + + # prepare compiled model from scratch + del self.ov_compiled_model + ov_inputs = [] + for _input in self._inputs: + ov_type = OPENVINO_DTYPES[_input.dtype] + ov_shape = _input.shape + ov_shape = list(ov_shape) + for i in range(len(ov_shape)): + if ov_shape[i] is None: + ov_shape[i] = -1 + param = ov_opset.parameter(shape=ov_shape, dtype=ov_type) + ov_inputs.append(param) + # build OpenVINO graph ov.Model + ov_outputs = self._run_through_graph( + ov_inputs, operation_fn=lambda op: op + ) + ov_outputs = tree.flatten(ov_outputs) + ov_model = ov.Model(results=ov_outputs, parameters=ov_inputs) + self.ov_compiled_model = ov.compile_model(ov_model, get_device()) + self.ov_device = get_device() + return self.ov_compiled_model + + def make_predict_function(self, force=False): + if self.predict_function is not None and not force: + return self.predict_function + + def one_predict_step(data): + data = data[0] + return self.predict_step(data) + + def multi_predict_steps(data): + outputs = one_predict_step(data[:1]) + + for single_step_data in data[1:]: + step_outputs = one_predict_step([single_step_data]) + outputs = tree.map_structure( + lambda t1, t2: np.concatenate([t1, t2]), + outputs, + step_outputs, + ) + return outputs + + if self.steps_per_execution > 1: + predict_step = multi_predict_steps + else: + predict_step = one_predict_step + + self.predict_function = predict_step + + def fit( + self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose="auto", + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, + ): + raise NotImplementedError( + "`fit` is not supported with openvino backend" + ) + + @traceback_utils.filter_traceback + def predict( + self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + ): + # Create an iterator that yields batches of input data. + epoch_iterator = EpochIterator( + x=x, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + def append_to_outputs(batch_outputs, outputs): + if outputs is None: + outputs = tree.map_structure( + lambda batch_output: [batch_output], + batch_outputs, + ) + else: + tree.map_structure_up_to( + batch_outputs, + lambda output, batch_output: output.append(batch_output), + outputs, + batch_outputs, + ) + return outputs + + self.make_predict_function() + self.stop_predicting = False + callbacks.on_predict_begin() + outputs = None + for step, data in epoch_iterator.enumerate_epoch(): + callbacks.on_predict_batch_begin(step) + batch_outputs = self.predict_function(data) + outputs = append_to_outputs(batch_outputs, outputs) + callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) + if self.stop_predicting: + break + callbacks.on_predict_end() + return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) + + @traceback_utils.filter_traceback + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + raise NotImplementedError( + "`evaluate` is not supported with openvino backend" + ) + + def train_on_batch( + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, + ): + raise NotImplementedError( + "`train_on_batch` is not supported with openvino backend" + ) + + def test_on_batch( + self, + x, + y=None, + sample_weight=None, + return_dict=False, + ): + raise NotImplementedError( + "`test_on_batch` is not supported with openvino backend" + ) + + def predict_on_batch(self, x): + self.make_predict_function() + batch_outputs = self.predict_function([(x,)]) + batch_outputs = tree.map_structure( + backend.convert_to_numpy, batch_outputs + ) + return batch_outputs diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index d065fdd2fdf..3756b79e093 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -52,6 +52,8 @@ from keras.src.backend.torch.layer import TorchLayer as BackendLayer elif backend.backend() == "numpy": from keras.src.backend.numpy.layer import NumpyLayer as BackendLayer +elif backend.backend() == "openvino": + from keras.src.backend.openvino.layer import OpenvinoLayer as BackendLayer else: raise RuntimeError( f"Backend '{backend.backend()}' must implement a layer mixin class." @@ -311,6 +313,9 @@ def __init__( self._convert_input_args = True # Whether to allow non-tensors as positional arguments in `call()`. self._allow_non_tensor_positional_args = False + if backend.backend() == "openvino": + self._allow_non_tensor_positional_args = True + # Dict of shapes that were used to call `build()`. self._build_shapes_dict = None # Parent path diff --git a/keras/src/models/model.py b/keras/src/models/model.py index e03e6dc97bd..e3bcf9d162f 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -23,6 +23,8 @@ from keras.src.backend.torch.trainer import TorchTrainer as Trainer elif backend.backend() == "numpy": from keras.src.backend.numpy.trainer import NumpyTrainer as Trainer +elif backend.backend() == "openvino": + from keras.src.backend.openvino.trainer import OpenVINOTrainer as Trainer else: raise RuntimeError( f"Backend '{backend.backend()}' must implement the Trainer class." diff --git a/keras/src/utils/backend_utils.py b/keras/src/utils/backend_utils.py index 3e974aa3e6e..05e637d444b 100644 --- a/keras/src/utils/backend_utils.py +++ b/keras/src/utils/backend_utils.py @@ -60,10 +60,10 @@ def __init__(self, backend=None): self._backend = backend or backend_module.backend() def set_backend(self, backend): - if backend not in ("tensorflow", "jax", "torch", "numpy"): + if backend not in ("tensorflow", "jax", "torch", "numpy", "openvino"): raise ValueError( - "Available backends are ('tensorflow', 'jax', 'torch' and " - f"'numpy'). Received: backend={backend}" + "Available backends are ('tensorflow', 'jax', 'torch', " + f"'numpy' and 'openvino'). Received: backend={backend}" ) self._backend = backend @@ -92,6 +92,9 @@ def __getattr__(self, name): "Currently, we cannot dynamically import the numpy backend " "because it would disrupt the namespace of the import." ) + if self._backend == "openvino": + module = importlib.import_module("keras.src.backend.openvino") + return getattr(module, name) @keras_export("keras.config.set_backend") diff --git a/requirements.txt b/requirements.txt index 14ba558acea..3f83fd5e06b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,5 +13,8 @@ torchvision>=0.16.0 jax[cpu] flax +# OpenVINO +openvino~=2024.4.0 + # Common deps. -r requirements-common.txt