RatInABox2.0 - Roadmap #84
Replies: 11 comments 3 replies
-
args, not dicts 👍 A helpful case study in support of args ... Most of you (I'm sure) have seen Grant Sanderson's beautiful 3blue1brown YouTube channel. Grant impressively homebrewed the … Similar to here, Grant Sanderson's fork is also trying to remove them: 3b1b/manim#1932

plotting 👍💯 Replotting = slow. ... If ratinabox caches plot objects, I super recommend the scheme we chatted about: … The TaskEnvironment has a weak version of this feature: it doesn't replot everything and thus renders quickly. But in my view it's pretty hacky that the environment caches things about its agents and goals. In the long run, it will be more maintainable to have each class in charge of caching its own plot objects, rather than having to change the master supervisor class's plotting every time the child classes change.

type hinting 👍 Especially for easy-to-type variables. Tools like …

unit testing 👍

global environment 👍 Possible suggestion: each RIB class could have a list of children (…

Jax 🤷‍♂️ No strong opinions. Leaning towards partial Jax if the penalty for binary ops / shuttling numpy arrays to a CPU jax.device is low.
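To make the "each class caches its own plot objects" idea concrete, here is a minimal sketch assuming matplotlib. `CachedTrajectoryPlot` and its methods are hypothetical names, not RatInABox API: the first call to `plot()` creates a `Line2D`, and later calls only update that artist's data with `set_data` instead of replotting.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


class CachedTrajectoryPlot:
    """Hypothetical example: an object that owns and caches its own plot artists."""

    def __init__(self):
        self.xs = []
        self.ys = []
        self._line = None  # cached Line2D, created on the first plot() call
        self.n_artists_created = 0  # how many times a new artist was actually built

    def step(self, x, y):
        self.xs.append(x)
        self.ys.append(y)

    def plot(self, ax):
        if self._line is None or self._line.axes is not ax:
            # First call (or a different axes): create the artist once
            (self._line,) = ax.plot(self.xs, self.ys, color="C0")
            self.n_artists_created += 1
        else:
            # Later calls: update the existing artist in place, no replotting
            self._line.set_data(self.xs, self.ys)
            ax.relim()
            ax.autoscale_view()
        return self._line


fig, ax = plt.subplots()
traj = CachedTrajectoryPlot()
for t in range(5):
    traj.step(float(t), float(t) ** 2)
    traj.plot(ax)  # only the first call creates a Line2D
```

Because the object that owns the data also owns its artists, a supervising class (e.g. an environment) would only need to call `child.plot(ax)` on each child without knowing how each one draws itself.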
-
Sounds like a great idea overall for the longevity of the package! I definitely agree with args instead of dicts, type hinting and unit testing. For the global environment, if the cascading update is implemented, I would suggest having a … I would suggest an additional section: …
-
Great comments, thanks guys. @SynapticSage 3B1B advice heeded! @colleenjg you're right, this could be more modular, for example …
-
These all sound like great changes for RAIB 2.0, and I agree with all of the comments from @SynapticSage and @colleenjg :) I'm a particularly big fan of the global environment updating, as this seems much more concise. My only concern is whether this would slow down updates for really long simulations (like the ones I have been running, e.g. 30 Hz × 31 sessions × 40 min/session). It might be ideal to perform more selective updates, skipping components that are going to be static, using some sort of argument in …

As for Jax compatibility, I would be very much in favour of this if it can actually speed things up for the heavier computations and long simulations, but as you point out, it might not save compute time if large arrays are being converted often. I believe it would be worth some case testing in a couple of large simulations before ruling this out.
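A minimal sketch of how the cascading `update()` proposed in the roadmap could be combined with selective updates: the environment holds the global clock and updates its children, which update theirs, while anything flagged as static is skipped. The `Node`/`Env` classes and the `static` flag are hypothetical, not RatInABox API. For scale, 30 Hz × 31 sessions × 40 min/session is 30 × 60 × 40 × 31 = 2,232,000 update steps, so per-step overhead adds up quickly.

```python
class Node:
    """Hypothetical base class: anything that is updated and may own children."""

    def __init__(self, name, static=False):
        self.name = name
        self.static = static  # static nodes are skipped by the cascade
        self.children = []
        self.n_updates = 0

    def add_child(self, child):
        self.children.append(child)
        return child

    def update(self, t, dt):
        # A real class would advance its own state here; we just count calls
        self.n_updates += 1
        # Then cascade to children, skipping any that are flagged static
        for child in self.children:
            if not child.static:
                child.update(t, dt)


class Env(Node):
    """Hypothetical environment that owns the global clock."""

    def __init__(self, dt=1 / 30):
        super().__init__("env")
        self.t = 0.0
        self.dt = dt

    def step(self):
        self.t += self.dt
        self.update(self.t, self.dt)


env = Env(dt=1 / 30)  # 30 Hz
agent = env.add_child(Node("agent"))
place_cells = agent.add_child(Node("place_cells"))
walls = env.add_child(Node("walls", static=True))  # never moves, never updated

for _ in range(30 * 60):  # one minute of simulated time
    env.step()

steps_per_min = 30 * 60
total_steps = steps_per_min * 40 * 31  # the long-simulation example above
```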
-
Thanks for the feedback, closing for now.
-
One thing that just occurred to me, which could be considered: only passing … In typical use cases, to my knowledge, passing both should be redundant, as you can access the figure with …
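A minimal sketch of a plotting function that accepts only an `ax` and recovers the parent figure from matplotlib's `Axes.figure` attribute when it needs one. `plot_something` is a hypothetical name used only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def plot_something(ax=None):
    """Hypothetical plotting function that takes only `ax`, never `fig`."""
    if ax is None:
        # No axes passed: create a fresh figure and axes
        fig, ax = plt.subplots()
    else:
        # Axes passed: its parent figure is always reachable from it
        fig = ax.figure
    ax.plot([0, 1], [0, 1])
    return fig, ax


fig, ax = plt.subplots()
fig_out, ax_out = plot_something(ax=ax)
```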
-
Agreed and added to the list. It's essentially redundant and only adds bloat.
-
If you add … As the primary maintainer of the Fedora Linux package for this project, I’m not sure whether packaging https://github.com/google/jax would be feasible for us. While it does look like …
-
@musicinmybrain thanks for your feedback. That's OK, I doubt we'd go full …
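If Jax support stays partial, one way to keep it optional (so a distribution package can ship without it) is a guarded import that falls back to NumPy. This is a minimal sketch of that pattern; the `xp`/`USE_JAX` names and the toy `gaussian_rate` function are hypothetical, not existing RatInABox code.

```python
import numpy as np

try:
    import jax.numpy as jnp  # optional accelerated backend
    USE_JAX = True
except ImportError:
    jnp = None
    USE_JAX = False

# Array module used by the code below: jax.numpy if installed, else plain NumPy
xp = jnp if USE_JAX else np


def gaussian_rate(positions, centre, width):
    """Toy firing-rate function written against `xp`, so it runs with either backend."""
    d2 = xp.sum((xp.asarray(positions) - xp.asarray(centre)) ** 2, axis=-1)
    return xp.exp(-d2 / (2 * width**2))


rates = gaussian_rate([[0.0, 0.0], [1.0, 0.0]], [0.0, 0.0], 1.0)
```

The design trade-off is the one raised above: code written against `xp` stays portable, but mixing NumPy and Jax arrays in the same computation can still trigger device transfers.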
-
We should consider type hinting for RiaB 2.0 (https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html). This will make lookups easier for users and help with code autocompletion for anyone coding in an integrated development environment (IDE). It will mostly involve 2 things: …
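As an illustration of what type hints could look like in RiaB 2.0, here is a minimal sketch in the style of the mypy cheat sheet linked above. `AgentState` and `step_agent` are hypothetical names, not the current RatInABox API; `numpy.typing.NDArray` requires NumPy 1.21 or later.

```python
from __future__ import annotations

from dataclasses import dataclass, field

import numpy as np
from numpy.typing import NDArray


@dataclass
class AgentState:
    """Hypothetical container with typed fields, so IDEs can autocomplete `.pos`, `.t` etc."""

    pos: NDArray[np.float64]
    t: float = 0.0
    history: list[NDArray[np.float64]] = field(default_factory=list)


def step_agent(state: AgentState, velocity: NDArray[np.float64], dt: float = 0.01) -> AgentState:
    """Advance the agent by one Euler step and return the updated state."""
    state.history.append(state.pos.copy())
    state.pos = state.pos + velocity * dt
    state.t += dt
    return state


s = AgentState(pos=np.zeros(2))
step_agent(s, np.array([1.0, 0.0]), dt=0.1)
```

With hints like these, mypy would flag a call such as `step_agent(s, velocity="fast")` before it is ever run.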
-
@mehulrastogi just thinking... do you think a more scalable way to support dynamic environments would be for all environmental objects to be their own class? I can imagine a … This is as opposed to the solution I'd previously imagined, where the environment itself stores a "state" dictionary which can be updated. The nice thing about this new proposal is that … My main concern is that storing so much more data could create memory issues, but to be honest I'm not sure it would be much more data than Neurons already save without problems. We could also imagine some memory-clever solution, since 99% of the time walls won't move.
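A minimal sketch of "every environmental object is its own class" with the memory-saving idea from this comment and the roadmap: an object appends to its history only when its state actually changes (via a property setter), so a wall that never moves stores a single entry, and past states can be looked up by time. `EnvObject` and `state_at()` are hypothetical names, and the per-object `clock` stands in for the environment's global clock.

```python
import bisect


class EnvObject:
    """Hypothetical environment object whose history is logged only on change."""

    def __init__(self, name, position, t=0.0):
        self.name = name
        self._position = tuple(position)
        self.history_t = [t]  # times at which the state changed
        self.history_pos = [self._position]  # state after each change
        self.clock = t  # stand-in for the Env's global clock

    @property
    def position(self):
        return self._position

    @position.setter
    def position(self, value):
        value = tuple(value)
        if value != self._position:  # only log genuine changes
            self._position = value
            self.history_t.append(self.clock)
            self.history_pos.append(value)

    def state_at(self, t):
        """Return the position at time t (the last change at or before t)."""
        i = bisect.bisect_right(self.history_t, t) - 1
        return self.history_pos[max(i, 0)]


wall = EnvObject("wall", (0.0, 0.5))
door = EnvObject("door", (1.0, 0.0))

for step in range(1, 101):
    t = step * 0.1
    wall.clock = door.clock = t
    wall.position = (0.0, 0.5)  # unchanged, so nothing is appended
    if step == 50:
        door.position = (1.0, 1.0)  # the door moves once
```

A plotting routine could then call `state_at(t)` on each object to redraw the environment as it was at any time `t`.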
-
I've begun to think about 2.0. The reason is that there are certainly a couple of choices I made early on in development which weren't optimal. Now could be a good time to fix these, as the community is growing but still small enough that it won't be super disruptive. Fixing them will also make RiaB easier to maintain in the long run.
I'm opening this issue to get community thoughts on this. @SynapticSage @colleenjg @jquinnlee @mehulrastogi you're some of the most active users I know fairly well, so I'm tagging you to get your input (if you have any), but anyone can chip in here. Here are my thoughts:
Essential and backwards incompatible changes (do first):

- `Neurons` classes in one `.py` file.
- `.update()`: Given that `Environments` now know about their `Agents`, and `Agents` know about their `Neurons`, we could have just one update function in `Env` which cascades through everything else. Cleaner?
- `dev` --> `main`
- `Environment` stores the global clock. This just makes sense imo.
- `drift_velocity` kwarg. Maybe instead `Agent`s can have a `policy()` method which returns a drift; this would default to the random motion policy, unifying that too. Just something to consider.
- `get_state()` repeats logic each time. It seems better to have a unique `.forward()` (or maybe called `calculate_firing_rate()`) method for each class, which receives arrays of positions, head directions etc., and a shared `.get_state()` which lives in `Neurons` and calls `.forward()`. I'm returning to +1 this idea, it makes a lot more sense. Also, instead of `get_state()` we can have numerous functions such as `get_agent_firing_rate()`, `get_rate_map()`, `get_angular_tuning_curve()` etc. These "get"-functions should return not just lists of the firing rates but also lists of the respective coordinates (we could maybe use `xarray` for this, but I don't like the extra dependency).
- An `EnvironmentEntity` class could be made which can be added to an `Environment`, which will then update each entity at each time step. Entities will have their own `.render` methods, so the `Env` can loop over them and plot them too. This could be used to flexibly build `Teleports`, `Doors`, `Keys`, etc. for more dynamic environments.

Other essential changes:
- … `update()` perhaps adding into new agent/neuron/env specific utils scripts.
- … `Env.history` dictionary. Then, when plotting / animating the environment, we can pass in a time argument and the correct state can be retrieved and plotted. The state of the environment only appends to history whenever it changes (e.g. when a setter is called).
- … `plot_environment()` it can be passed a `fig`, an `ax` and a new object `R`: a list/dict of plot objects which are all `matplotlib.Artists` already existing on the figure. The environment can store an equivalent list of plot objects, and whenever this changes (e.g. a wall is added or an object is moved) the change is logged. Plotting can then (i) get the list of plot objects corresponding to the correct time and (ii) compare it to the passed list: if they aren't equal, replot the env, otherwise don't bother. Something like that.
- `Environment`s have an `Env.history` dictionary storing the full "state" of the environment (all object locations, walls, boundaries, etc.). Then `Env.plot_environment()` takes a time argument, finds the state of the environment at that time and plots that.
- … `Agent.update()`: as shown in the paper, this is a significant bottleneck.
- … `animate_` API with `plot_` API, just with a few extra kwargs.
- … `ax` not `fig` to figure plotting functions. This may throw up some things, but likely minor.
- … `utils.py` into separate ones for the `Agent` package, `Neurons` package and `Env` package, and maybe also a `misc`.
- … `RatInABox/RatInABox` not `TomGeorge1234/RatInABox`
- … `RatInABox/RatInABox_RL` package containing all the RL stuff (`Actor`, `Critic`, `ValueNeuron`, `TDError`, `TaskEnv` etc.)
- `IntermediateNeurons` subclass for neurons which aren't "fundamental" but take other neurons as inputs. Current examples are `FeedForwardLayer` and `NeuralNetworkNeurons`.
- `DynamicNeurons` subclass for neurons which aren't static, i.e. you can't call `Neurons.plot_rate_map()` because they actually depend on the past history. Examples include `TDErrorNeurons` (to be made) or anything with recurrency.
- `SmoothRandomFeatureNeurons`: just some spatially tuned but random neurons. Users just provide a length scale. Would be useful for a lot of feature learning studies. Probably something like a Gaussian process underlying these neurons.
- … `args`, more `kwargs`. Some of the functions (in particular plotting) have quite bloated argument lists. I think it would be better to remove some of these and allow them to be hidden in `**kwargs`, then defined at the top of the function as `arg = kwargs.get("arg_name", default_val)`. This is backwards compatible, cleans up the docstrings so they're more readable, and we can use this to expose any/all free parameters (even ones which weren't an argument beforehand) for greater flexibility. I have done this to `Agent.plot_trajectory()`.
Things to consider:

- `Neurons` should follow the `torch.nn.Module` API. This would make the evaluation of complex feedforward graphs, which currently happens in a backwards manner, more efficient. This might require renaming the `.get_state()` method to `.forward()`. Need to think more about this.
- `np` --> `jnp` everywhere.

I'm not a software guy, so @SynapticSage @mehulrastogi feel free to give high-level comments about the best way to go forward.
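The `.forward()` / shared `.get_state()` split proposed in the roadmap, loosely in the spirit of how `torch.nn.Module` routes calls to `forward()`, could look something like this minimal sketch. It uses plain NumPy rather than torch, and the class names, the `agent_pos` argument and a `get_rate_map()` that returns coordinates alongside rates are illustrative assumptions, not the current RatInABox API.

```python
import numpy as np


class Neurons:
    """Hypothetical base class: shared plumbing here, subclasses only write forward()."""

    def forward(self, pos):
        """Map an (N, 2) array of positions to an (N, n_neurons) array of firing rates."""
        raise NotImplementedError

    def get_state(self, agent_pos):
        # Shared logic: evaluate at the agent's current position
        return self.forward(np.atleast_2d(agent_pos))[0]

    def get_rate_map(self, extent=(0.0, 1.0, 0.0, 1.0), dx=0.1):
        # Shared logic: evaluate on a grid and return coordinates alongside the rates
        xs = np.arange(extent[0], extent[1] + 1e-9, dx)
        ys = np.arange(extent[2], extent[3] + 1e-9, dx)
        X, Y = np.meshgrid(xs, ys)
        coords = np.stack([X.ravel(), Y.ravel()], axis=-1)
        return coords, self.forward(coords)


class GaussianPlaceCells(Neurons):
    """Toy subclass: only the firing-rate computation is specific to this class."""

    def __init__(self, centres, width=0.2):
        self.centres = np.asarray(centres, dtype=float)  # (n_neurons, 2)
        self.width = width

    def forward(self, pos):
        # Squared distance from every position to every centre: (N, n_neurons)
        d2 = ((pos[:, None, :] - self.centres[None, :, :]) ** 2).sum(axis=-1)
        return np.exp(-d2 / (2 * self.width**2))


pcs = GaussianPlaceCells(centres=[[0.5, 0.5], [0.2, 0.8]])
rates_now = pcs.get_state([0.5, 0.5])
coords, rate_map = pcs.get_rate_map()
```

Since every "get"-function goes through the same vectorised `forward()`, a new neuron type only has to implement that one method.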