Can't control batch size when all dims are input dims #163
Comments
Alternatively, is it possible in this scenario to "rechunk" along the sample dimension (so you'd get like 32 x lon x lat)?
So I figured out I can use
A temporary solution can be to just create a size 1 expanded dimension. Example:
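The example that followed this comment does not appear in this copy; a rough sketch of the expansion itself (dimension and variable names here are made up):

import numpy as np
import xarray as xr

ds = xr.Dataset({'da1': (('lat', 'lon'), np.random.rand(100, 100))})
# Add a dummy dimension of size 1, giving the dataset a dim that is not an input dim.
ds_expanded = ds.expand_dims('dummy')
ds_expanded.sizes  # now includes 'dummy': 1 alongside 'lat' and 'lon'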
Yeah this is connected to some other general weirdness about the number of input vs. concat dims. I'll try adding dummy dimensions again and see what I get (but it would be nice not to have to hack around this). I am having some success with writing a generator wrapper like this (bad code alert!):
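The snippet referenced here is missing from this copy; a minimal sketch of such a wrapper (a reconstruction, not the author's exact code):

import itertools
import xarray as xr

def fixed_size_batches(bgen, batch_size):
    # Group patches from an xbatcher BatchGenerator into fixed-size batches
    # by concatenating batch_size of them along a new 'sample' dimension.
    it = iter(bgen)
    while True:
        chunk = list(itertools.islice(it, batch_size))
        if len(chunk) < batch_size:
            return  # underlying generator exhausted (drops any partial batch)
        yield xr.concat(chunk, 'sample')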
However, you get a different kind of slowdown from having to slice the batch generator at different points. I realized you don't have to use the NN model's batch size here, so it could be larger, and you could find a good compromise between time spent slicing the batch generator and time retrieving the next batch in your training loop.
Also wanted to note that this issue turns xbatcher into a massive memory hog, and it's probably related to #37 as well.
Why is there a deep copy here?
As I noted on the other thread, that is not a deep copy. It's a very shallow copy. Creating a copy of the data array avoids causing side effects to the user's inputs.
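For context, a shallow copy in xarray creates a new object but shares the underlying array buffer, so it is cheap; a small illustration:

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4))
shallow = da.copy(deep=False)  # new DataArray object, same underlying data buffer
print(np.shares_memory(da.values, shallow.values))  # True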
Along the lines of #162 (comment), we can create fixed-size batches for the case of all dims being input dims by using a generator wrapper around the BatchGenerator:
import xarray as xr
import numpy as np
import xbatcher as xb
da1 = xr.DataArray(np.random.randint(0,9,(400,400)), dims=['d1', 'd2'])
da2 = xr.DataArray(np.random.randint(0,9,(400,400)), dims=['d1', 'd2'])
da3 = xr.DataArray(np.random.randint(0,9,(400,400)), dims=['d1', 'd2'])
ds = xr.Dataset({'da1':da1, 'da2':da2, 'da3':da3})
def batch_generator(bgen, batch_size):
    # Concatenate batch_size patches from the BatchGenerator along a new
    # 'sample' dimension to form each fixed-size batch.
    b = iter(bgen)
    while True:
        try:
            batch_stack = [next(b) for _ in range(batch_size)]
        except StopIteration:
            return  # underlying generator exhausted (drops any partial final batch)
        yield xr.concat(batch_stack, 'sample')
bgen = xb.BatchGenerator(
    ds,
    {'d1': 20, 'd2': 20},  # input_dims
    {'d1': 2, 'd2': 2},    # input_overlap
)
gen = batch_generator(bgen, 32)

# Pull the first fixed-size batch from the wrapper (note: iterate gen, not the raw bgen).
a = next(gen)
a
Is your feature request related to a problem?
Title. Basically, in most cases you can control your batch size by setting the batch_dims option in BatchGenerator. However, if you don't have any batch dims to start with, you are effectively unable to control your batch size. E.g., for an xarray Dataset ds with dims lat and lon, a BatchGenerator that uses only input dims (see the sketch below) offers no option to control batch size.
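For illustration (the variable name and sizes below are made up), the setup in question looks something like:

import numpy as np
import xarray as xr
import xbatcher as xb

ds = xr.Dataset({'t2m': (('lat', 'lon'), np.random.rand(100, 200))})

# Every dimension is an input dim, so there is nothing left to pass as batch_dims
# and therefore no way to ask for, say, 32 samples per batch.
bgen = xb.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10})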
Describe the solution you'd like
I want to be able to pass an integer to BatchGenerator that tells it the size of the batch I want, in the case described above.
Maybe something like this, but wrapped as a BatchGenerator option.
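A purely hypothetical sketch of what that option could look like, continuing the example above (batch_size is not an existing BatchGenerator argument):

# Hypothetical API, for illustration only.
bgen = xb.BatchGenerator(
    ds,
    input_dims={'lat': 10, 'lon': 10},
    batch_size=32,  # proposed: concatenate 32 patches along a new 'sample' dim per batch
)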
Describe alternatives you've considered
No response
Additional context
I think this can probably be solved at the same time as #127