
Draft: Tobi Dance #22

Draft pull request: wants to merge 23 commits into base: master from Tobi_Dance

Changes from all commits (23 commits)
20df0c1
added logging.info to inform about not compiling
tobidelbruck Nov 26, 2022
9e7d4fa
to fix tensorflow JIT compile case insensitivity problem, renamed cart …
tobidelbruck Nov 27, 2022
8c60986
generate pycharm link in log output.
tobidelbruck Nov 27, 2022
4c0bd92
added nan value to all libs
tobidelbruck Nov 28, 2022
1681fdd
rename config_cost_function.yml to config_cost_functions.yml for cons…
tobidelbruck Nov 28, 2022
4776ed4
finally the dynamically modifiable control cost parameters are workin…
tobidelbruck Dec 11, 2022
9e72b37
now spin and balance both work! and so does changing the policy and …
tobidelbruck Dec 12, 2022
3c1895f
Pull from master branch
frehe Dec 14, 2022
0c6ea2a
Rename tensorflow compilation flags
frehe Dec 15, 2022
8d5eddd
add equal and pow methods
tobidelbruck Dec 18, 2022
1ddc244
Merge remote-tracking branch 'origin/Tobi_Dance' into Tobi_Dance
tobidelbruck Dec 18, 2022
d343f31
renamed s to state for clarity in many of the classes.
Dec 24, 2022
91a611e
Merge remote-tracking branch 'origin/main' into Tobi_Dance
tobidelbruck Jan 31, 2023
a6fda56
Merge remote-tracking branch 'origin/master' into Tobi_Dance
tobidelbruck Jan 31, 2023
7f7659d
moved get_logger to own file in SI_Toolkit
tobidelbruck Feb 6, 2023
4d34dde
update path to config_cost_functions.yml
tobidelbruck Feb 7, 2023
76d455d
Merge remote-tracking branch 'origin/Tobi_Dance' into Tobi_Dance
tobidelbruck Feb 7, 2023
f13330c
move get_logger.py to Control_Toolkit so that it can be used by physi…
tobidelbruck Feb 8, 2023
5155fe9
cartpole_dancer.py starts to work. Music starts and stops, some steps…
tobidelbruck Feb 10, 2023
f63ab96
added primitive ability to record the predictor_ODE_tf.py predictions…
tobidelbruck Feb 13, 2023
ab87fec
added 'cartwheel' step to cartpole_trajectory_generator.py.
tobidelbruck Feb 16, 2023
3dd5102
fixed some logic and reduced some loggers to debug level
tobidelbruck Feb 19, 2023
89f4ead
major changes to cartpole_dancer_cost and cartpole_trajectory_generat…
tobidelbruck Feb 28, 2023
15 changes: 13 additions & 2 deletions SI_Toolkit_ASF_Template/__init__.py
@@ -1,2 +1,13 @@
GLOBALLY_DISABLE_COMPILATION = False # Set to False to use tf.function
USE_JIT_COMPILATION = True # XLA ignores random seeds. Set to False for reproducibility
### Choose whether to run TensorFlow in eager mode (slow, interpreted) or graph mode (fast, compiled)
# Set `USE_TENSORFLOW_EAGER_MODE=False` to...
# - decorate functions in optimizers and predictors with `@tf.function`.
# - and thereby enable TensorFlow graph mode. This is much faster than the standard eager mode.
USE_TENSORFLOW_EAGER_MODE = False


### Choose whether to use TensorFlow Accelerated Linear Algebra (XLA).
# XLA uses machine-specific conversions to speed up the compiled TensorFlow graph.
# Set USE_TENSORFLOW_XLA to True to accelerate the execution (for real-time).
# If `USE_TENSORFLOW_XLA=True`, this adds `jit_compile=True` to the `tf.function` decorator.
# However, XLA ignores random seeds. Set to False for guaranteed reproducibility, such as for simulations.
USE_TENSORFLOW_XLA = True
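A minimal sketch (not part of the diff) of how these two flags combine, mirroring the selection logic in src/SI_Toolkit/Functions/TF/Compile.py further down; the helper name choose_compiler is illustrative:

import tensorflow as tf

def choose_compiler(use_tensorflow_eager_mode: bool, use_tensorflow_xla: bool):
    if use_tensorflow_eager_mode:
        return lambda f: f                                 # eager mode: no compilation
    if use_tensorflow_xla:
        return lambda f: tf.function(f, jit_compile=True)  # graph mode plus XLA JIT
    return tf.function                                     # graph mode only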
2 changes: 1 addition & 1 deletion src/SI_Toolkit/Functions/General/Dataset.py
@@ -191,7 +191,7 @@ def get_batch(self, idx_batch):
def reset_batch_size(self, batch_size=None):

if batch_size is None:
self.batch_size = self.args.batch_size
self.batch_size = self.args.num_rollouts
else:
self.batch_size = batch_size

2 changes: 1 addition & 1 deletion src/SI_Toolkit/Functions/Pytorch/Training.py
@@ -35,7 +35,7 @@ def train_network_core(net, net_info, training_dfs_norm, validation_dfs_norm, te
del training_dfs_norm, validation_dfs_norm, test_dfs_norm

# Create PyTorch dataloaders for train and dev set
training_generator = data.DataLoader(dataset=training_dataset, batch_size=a.batch_size, shuffle=True)
training_generator = data.DataLoader(dataset=training_dataset, batch_size=a.num_rollouts, shuffle=True)
validation_generator = data.DataLoader(dataset=validation_dataset, batch_size=512, shuffle=False)

print('')
35 changes: 25 additions & 10 deletions src/SI_Toolkit/Functions/TF/Compile.py
@@ -1,19 +1,22 @@
import logging
import platform

import tensorflow as tf
import torch

from Control_Toolkit.others.get_logger import get_logger
log=get_logger(__name__)

from SI_Toolkit.computation_library import ComputationLibrary



try:
from SI_Toolkit_ASF import GLOBALLY_DISABLE_COMPILATION, USE_JIT_COMPILATION
from SI_Toolkit_ASF import USE_TENSORFLOW_EAGER_MODE, USE_TENSORFLOW_XLA
except ImportError:
logging.warn("No compilation option set in SI_Toolkit_ASF. Setting GLOBALLY_DISABLE_COMPILATION to True.")
GLOBALLY_DISABLE_COMPILATION = True
raise Exception("One or both of the compilation options USE_TENSORFLOW_EAGER_MODE and USE_TENSORFLOW_XLA are missing in SI_Toolkit_ASF/__init__.py.")

def tf_function_jit(func):
return tf.function(func=func, jit_compile=True)
# log.debug(f'compiling tf.function from {func}')
return tf.function(func=func, jit_compile=True)


def tf_function_experimental(func):
@@ -24,27 +27,39 @@ def identity(func):
return func


if GLOBALLY_DISABLE_COMPILATION:
if USE_TENSORFLOW_EAGER_MODE:
log.warning('TensorFlow compilation is disabled by USE_TENSORFLOW_EAGER_MODE=True and execution will be extremely slow')
CompileTF = identity
else:
if platform.machine() == 'arm64' and platform.system() == 'Darwin': # For M1 Apple processor
log.info('TensorFlow graph compilation (without XLA JIT) is enabled on Apple Silicon (arm64 Darwin), where XLA is not used')
CompileTF = tf.function
elif not USE_JIT_COMPILATION:
elif not USE_TENSORFLOW_XLA:
log.info('TensorFlow compilation (but not JIT) is enabled by tf.function by USE_TENSORFLOW_EAGER_MODE=False and USE_TENSORFLOW_XLA = False')
CompileTF = tf.function
else:
log.info('TensorFlow compilation and JIT are both enabled by tf.function_jit by USE_TENSORFLOW_EAGER_MODE=False and USE_TENSORFLOW_XLA = True')
CompileTF = tf_function_jit
log.info(f'using {CompileTF} compilation')
# CompileTF = tf_function_experimental # Should be same as tf_function_jit, not appropriate for newer version of TF

def CompileAdaptive(fun):
"""
Compiles the function using options for TensorFlow and XLA JIT, according to global flags USE_TENSORFLOW_EAGER_MODE.

See SI_Toolkit_ASF\__init__.py

"""
instance = fun.__self__
assert hasattr(instance, "lib"), "Instance with this method has no computation library defined"
computation_library: "type[ComputationLibrary]" = instance.lib
lib_name = computation_library.lib

if GLOBALLY_DISABLE_COMPILATION:
if USE_TENSORFLOW_EAGER_MODE:
return identity(fun)
elif lib_name == 'TF':
log.debug(f'compiling tensorflow {fun}')
return CompileTF(fun)
else:
print('Jit compilation for Pytorch not yet implemented.')
log.warning(f'JIT compilation for {lib_name} not yet implemented.')
return identity(fun)
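A hypothetical usage sketch of CompileAdaptive on a bound method; MyPredictor and step are illustrative names, not toolkit classes:

from SI_Toolkit.computation_library import TensorFlowLibrary
from SI_Toolkit.Functions.TF.Compile import CompileAdaptive

class MyPredictor:
    lib = TensorFlowLibrary  # CompileAdaptive reads this attribute off the bound instance

    def step(self, s):
        return s + 1.0

predictor = MyPredictor()
compiled_step = CompileAdaptive(predictor.step)  # wrapped with CompileTF, since lib.lib == 'TF'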
4 changes: 2 additions & 2 deletions src/SI_Toolkit/GP/DataSelector.py
@@ -138,8 +138,8 @@ def return_dataset_for_training(self,
raw=False
):

if batch_size is None and self.args.batch_size is not None:
batch_size = self.args.batch_size
if batch_size is None and self.args.num_rollouts is not None:
batch_size = self.args.num_rollouts

if inputs is None and self.args.inputs is not None:
inputs = self.args.inputs
4 changes: 2 additions & 2 deletions src/SI_Toolkit/GP/TimeGP.py
@@ -14,10 +14,10 @@ def timing_script_init():
m_loaded = load_model(save_dir)
print("Done!")

num_rollouts = 2000
batch_size = 2000
horizon = 35

s = tf.zeros(shape=[num_rollouts, 6], dtype=tf.float64)
s = tf.zeros(shape=[batch_size, 6], dtype=tf.float64)
m_loaded.predict_f(s)

return m_loaded, s
2 changes: 1 addition & 1 deletion src/SI_Toolkit/GP/Train_GPR.py
@@ -26,7 +26,7 @@
a.wash_out_len = 0
a.post_wash_out_len = 1
outputs = a.outputs
batch_size = a.batch_size
batch_size = a.num_rollouts

number_of_inducing_points = 10

4 changes: 3 additions & 1 deletion src/SI_Toolkit/Predictors/__init__.py
@@ -16,13 +16,15 @@ def __init__(self, horizon: float, batch_size: int) -> None:
self.predictor_external_input_features = CONTROL_INPUTS
self.predictor_output_features = STATE_VARIABLES

def predict_tf(self, s: tf.Tensor, Q: tf.Tensor):
def predict_tf(self, s: tf.Tensor, Q: tf.Tensor, time: float = None):
"""Predict the whole MPC horizon using tensorflow

:param s: Initial state [batch_size x state_dim]
:type s: tf.Tensor
:param Q: Control inputs [batch_size x horizon_length x control_dim]
:type Q: tf.Tensor
:param time: time in seconds
:type time: float
"""
raise NotImplementedError()

21 changes: 16 additions & 5 deletions src/SI_Toolkit/Predictors/predictor_ODE_tf.py
@@ -51,26 +51,37 @@ def __init__(self, horizon: int, dt: float, intermediate_steps=10, disable_indiv
self.predict_tf = CompileTF(self._predict_tf)


def predict(self, initial_state, Q):
def predict(self, initial_state, Q, time: float = None, horizon: int = None):
initial_state, Q = convert_to_tensors(initial_state, Q)
initial_state, Q = check_dimensions(initial_state, Q)

self.batch_size = tf.shape(Q)[0]
self.initial_state = initial_state

output = self.predict_tf(self.initial_state, Q)
output = self.predict_tf(self.initial_state, Q, params=None, time=time, horizon=horizon)

return output.numpy()


def _predict_tf(self, initial_state, Q, params=None):
def _predict_tf(self, initial_state, Q, params=None, time: float = None, horizon: int = None):
""" Predict the states over horizon next timesteps.
Q must be a 3-dimensional tensor [num_rollouts, horizon, control_dim] whose last axis holds the control inputs

self.output = tf.TensorArray(tf.float32, size=self.horizon + 1, dynamic_size=False)
:param initial_state: the state now
:param Q: the control over horizon next steps
:param params: optional parameters
:param time: the current time in seconds
:param horizon: optional horizon, if None then use self.horizon

:returns: the predicted states over the horizon; the first entry along the horizon dimension is the initial state, shape [num_rollouts, horizon+1, state_dim]
"""
horizon = self.horizon if horizon is None else horizon
self.output = tf.TensorArray(tf.float32, size=horizon + 1, dynamic_size=False)
self.output = self.output.write(0, initial_state)

next_state = initial_state

for k in tf.range(self.horizon):
for k in tf.range(horizon):
next_state = self.next_step_predictor.step(next_state, Q[:, k, :], params)
self.output = self.output.write(k + 1, next_state)

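For reference, a self-contained sketch of the TensorArray rollout pattern used in _predict_tf above, with toy additive dynamics standing in for the ODE step; all names and shapes are illustrative:

import tensorflow as tf

@tf.function
def rollout(initial_state, Q, horizon):
    output = tf.TensorArray(tf.float32, size=horizon + 1, dynamic_size=False)
    output = output.write(0, initial_state)
    state = initial_state
    for k in tf.range(horizon):
        state = state + Q[:, k, :]  # toy dynamics in place of next_step_predictor.step
        output = output.write(k + 1, state)
    return tf.transpose(output.stack(), perm=[1, 0, 2])  # [batch, horizon+1, state_dim]

states = rollout(tf.zeros([32, 6]), tf.ones([32, 50, 6]), tf.constant(50))  # shape [32, 51, 6]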
22 changes: 14 additions & 8 deletions src/SI_Toolkit/Predictors/predictor_wrapper.py
@@ -13,7 +13,7 @@

class PredictorWrapper:
"""Wrapper class for creating a predictor.

1) Instantiate this wrapper without parameters within the controller class
2) Pass the instance of this wrapper to the optimizer, without the need to already know specifics about it
3) Call this wrapper's `configure` method in controller class to set optimization-specific parameters
@@ -33,9 +33,10 @@ def __init__(self):
self.predictor_type: str = self.predictor_config['predictor_type']
self.model_name: str = self.predictor_config['model_name']

def configure(self, batch_size: int, horizon: int, dt: float, computation_library: "Optional[type[ComputationLibrary]]"=None, predictor_specification=None, compile_standalone=False, mode=None):
def configure(self, batch_size: int, horizon: int, dt: float, computation_library: Optional[ComputationLibrary]=None, predictor_specification=None, compile_standalone=False, mode=None):
"""Assign optimization-specific parameters to finalize instance creation.


:param batch_size: Batch size equals the number of parallel rollouts of the optimizer.
:type batch_size: int
:param horizon: Number of MPC horizon steps
@@ -75,11 +76,16 @@ def configure(self, batch_size: int, horizon: int, dt: float, computation_librar
self.predictor = predictor_ODE_tf(horizon=self.horizon, dt=dt, batch_size=self.batch_size, **self.predictor_config, **compile_standalone)

else:
raise NotImplementedError('Type of the predictor not recognised.')
raise NotImplementedError(f'Type of the predictor {self.predictor_type} is not recognised.')

# computation_library defaults to None. In that case, do not check for conformity.
if computation_library is not None and computation_library not in self.predictor.supported_computation_libraries:
raise ValueError(f"Predictor {self.predictor.__class__.__name__} does not support {computation_library.__name__}")
# otherwise, check that the configured predictor supports the requested computation library
if not computation_library is None and computation_library not in self.predictor.supported_computation_libraries:
raise ValueError(
f"Predictor {self.predictor.__class__.__name__} does not support {computation_library.__name__}")
Review comment (Member Author) on lines +83 to +85, suggested change:

Replace
    if not computation_library is None and computation_library not in self.predictor.supported_computation_libraries:
        raise ValueError(
            f"Predictor {self.predictor.__class__.__name__} does not support {computation_library.__name__}")
with
    if computation_library is not None and computation_library not in self.predictor.supported_computation_libraries:
        raise ValueError(
            f"Predictor {self.predictor.__class__.__name__} does not support {computation_library.__name__}"
        )

self.predictor.lib = computation_library  # set the library type on the predictor object so we can use it to assign attributes later


def configure_with_compilation(self, batch_size, horizon, dt, predictor_specification=None, mode=None):
"""
@@ -155,8 +161,8 @@ def update_predictor_config_from_specification(self, predictor_specification: st
def predict(self, s, Q):
return self.predictor.predict(s, Q)

def predict_tf(self, s, Q): # TODO: This function should disappear: predict() should manage the right library
return self.predictor.predict_tf(s, Q)
def predict_tf(self, state, Q, time=None): # TODO: This function should disappear: predict() should manage the right library
return self.predictor.predict_tf(state, Q, time=time)

def update(self, Q0, s):
if self.predictor_type == 'neural':
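A hypothetical usage sketch of the three-step lifecycle from the class docstring; the parameter values and the predictor_specification string are illustrative:

from SI_Toolkit.Predictors.predictor_wrapper import PredictorWrapper

predictor = PredictorWrapper()          # 1) instantiate without parameters in the controller
optimizer_predictor = predictor         # 2) hand the wrapper to the optimizer as-is
predictor.configure(                    # 3) finalize with optimization-specific parameters
    batch_size=32,                      # number of parallel rollouts
    horizon=50,                         # MPC horizon steps
    dt=0.02,                            # timestep in seconds
    predictor_specification="ODE_TF",   # illustrative specification string
)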
35 changes: 30 additions & 5 deletions src/SI_Toolkit/computation_library.py
@@ -55,7 +55,7 @@ class ComputationLibrary:
gather: Callable[[TensorType, TensorType, int], TensorType] = None
gather_last: Callable[[TensorType, TensorType], TensorType] = None
arange: Callable[[Optional[NumericType], NumericType, Optional[NumericType]], TensorType] = None
zeros: Callable[["tuple[int]"], TensorType] = None
zeros: Callable[["tuple[int,...]"], TensorType] = None
Review comment (Member Author): Very nice! I had wrong assumptions about how to type subscript tuples. Learned something again.
zeros_like: Callable[[TensorType], TensorType] = None
ones: Callable[["tuple[int]"], TensorType] = None
ones_like: Callable[[TensorType], TensorType] = None
@@ -92,9 +92,16 @@ class ComputationLibrary:
dot: Callable[[TensorType, TensorType], TensorType] = None
stop_gradient: Callable[[TensorType], TensorType] = None
assign: Callable[[Union[TensorType, tf.Variable], TensorType], Union[TensorType, tf.Variable]] = None
nan: TensorType = None
isnan: Callable[[TensorType], bool] = None
string = None
equal = lambda x, y: x == y
pow = lambda x, p: x ** p
where: Callable[[TensorType, TensorType, TensorType], TensorType] = None
logical_and: Callable[[TensorType, TensorType], TensorType] = None
logical_or: Callable[[TensorType, TensorType], TensorType] = None
dtype = lambda x: x.dtype
fill = None


class NumpyLibrary(ComputationLibrary):
@@ -169,19 +176,24 @@ class NumpyLibrary(ComputationLibrary):
dot = np.dot
stop_gradient = lambda x: x
assign = LibraryHelperFunctions.set_to_value
nan = np.nan
isnan = np.isnan
string = str
where = np.where
logical_and = np.logical_and
logical_or = np.logical_or

equal = lambda x, y: x == y
cond = lambda cond, t, f: t if cond else f
pow = lambda x, p: np.power(x, p)
fill = lambda dims, value: np.full(dims, value)  # fixed: the original lambda x,y: x.np.fill(y) would raise AttributeError; np.full matches tf.fill semantics

class TensorFlowLibrary(ComputationLibrary):
lib = 'TF'
reshape = tf.reshape
permute = tf.transpose
newaxis = tf.newaxis
shape = lambda x: x.get_shape() # .as_list()
to_numpy = lambda x: x.numpy()
shape = tf.shape  # prefer the dynamic tf.shape over the previous static lambda x: x.get_shape(); see review comment below
Review comment (@frehe, Dec 14, 2022): No difference: if x is a tf.Tensor, then x.shape and x.get_shape() are identical (https://www.tensorflow.org/api_docs/python/tf/Tensor#get_shape).

Using tf.shape(x) seems to be the best choice here, which is a third option. From https://www.tensorflow.org/api_docs/python/tf/shape: tf.shape and Tensor.shape should be identical in eager mode. Within tf.function or within a compat.v1 context, not all dimensions may be known until execution time. Hence, when defining custom layers and models for graph mode, prefer the dynamic tf.shape(x) over the static x.shape.
to_numpy = lambda x: x.numpy() if isinstance(x,(tf.Tensor, tf.Variable)) else x
to_variable = lambda x, dtype: tf.Variable(x, dtype=dtype)
to_tensor = lambda x, dtype: tf.convert_to_tensor(x, dtype=dtype)
constant = lambda x, t: tf.constant(x, dtype=t)
@@ -247,9 +259,16 @@ class TensorFlowLibrary(ComputationLibrary):
dot = lambda a, b: tf.tensordot(a, b, 1)
stop_gradient = tf.stop_gradient
assign = LibraryHelperFunctions.set_to_variable
nan = tf.constant(np.nan)
isnan = tf.math.is_nan
string = tf.string
where = tf.where
logical_and = tf.math.logical_and
logical_or = tf.math.logical_or
equal = lambda x, y: tf.math.equal(x, y)
cond = lambda cond, t, f: tf.cond(cond, t, f)
pow = lambda x, p: tf.pow(x, p)
fill = lambda dims, value: tf.fill(dims, value)

class PyTorchLibrary(ComputationLibrary):

@@ -332,6 +351,12 @@ def gather_last_pytorch(a, index_vector):
dot = torch.dot
stop_gradient = tf.stop_gradient # FIXME: How to implement this in torch?
assign = LibraryHelperFunctions.set_to_value
nan = torch.nan
isnan = torch.isnan
string = lambda x: torch.ByteTensor(bytes(x, 'utf8'))
where = torch.where
logical_and = torch.logical_and
logical_or = torch.logical_or
equal = lambda x, y: torch.equal(x, y)
pow = lambda x, p: torch.pow(x, p)
fill = lambda dims, value: torch.full(dims, value)  # fixed: the original x.torch.Tensor.fill(x, y) would raise AttributeError; torch.full mirrors tf.fill
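A short sketch of the backend-agnostic style the new members (nan, isnan, where, pow) enable; the helper masked_square is illustrative, and passing a scalar replacement to torch.where assumes a reasonably recent PyTorch:

import numpy as np
from SI_Toolkit.computation_library import NumpyLibrary

def masked_square(lib, x):
    x2 = lib.pow(x, 2)
    return lib.where(lib.isnan(x2), 0.0, x2)  # replace NaN entries with 0

print(masked_square(NumpyLibrary, np.array([1.0, np.nan, 3.0])))  # [1. 0. 9.]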