Data adapter: Broadcasting causes issue with sampling #258
Can you show the simulator and adapter you used? I will then check what we can do about it.

Malte Lüken wrote on Fri, 22 Nov 2024, 14:55:
> When using an `Adapter` with a `Broadcast` transform where `to="inference_variables"`, training runs normally, but when I try sampling it complains that it cannot find the key `"inference_variables"`, which I did not include in the `conditions` dictionary, following the example notebooks. My guess is that the `Broadcast` transform also needs a `strict` argument in its `forward` method for sampling, as in the `FilterTransform` class?
Thanks!

Malte Lüken wrote on Fri, 22 Nov 2024, 15:41:

> Sure! Here is my simulator:
```python
from typing import Callable

import numpy as np
import bayesflow as bf
from bayesflow.types import Shape  # import paths assumed
from bayesflow.utils import batched_call, tree_stack  # import paths assumed


class CustomSimulator(bf.simulators.Simulator):
    def __init__(self, prior_fun: Callable, design_fun: Callable, simulator_fun: Callable):
        self.prior_fun = prior_fun
        self.design_fun = design_fun
        self.simulator_fun = simulator_fun

    def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
        prior_dict = self.prior_fun(batch_shape)
        design_dict = self.design_fun(batch_shape)
        design_dict.update(**kwargs)
        sims_dict = self.simulator_fun(batch_shape, **prior_dict, **design_dict)
        data = prior_dict | design_dict | sims_dict
        # Ensure every variable has a trailing feature axis
        data = {
            key: np.expand_dims(value, axis=-1) if np.ndim(value) == 1 else value
            for key, value in data.items()
        }
        return data


def batch_simulator(batch_shape, simulator_fun, **kwargs):
    data = batched_call(simulator_fun, batch_shape, kwargs=kwargs, flatten=True)
    data = tree_stack(data, axis=0, numpy=True)
    return data


def rdm_experiment_simple(v_intercept, v_slope, s_true, s_false, b, t0, num_obs, rng):
    """Simulates data from a single subject in a multi-alternative response times experiment."""
    if np.any(np.array((v_intercept, v_slope, s_true, s_false, b, t0)) <= 0):
        raise ValueError("Model parameters must be positive")

    num_accumulators = 2
    # Accumulator 1 = false, accumulator 2 = true
    v = np.hstack([v_intercept, v_intercept + v_slope])
    s = np.hstack([s_false, s_true])
    mu = b / v
    lam = (b / s) ** 2

    # First passage time of each accumulator
    fpt = np.zeros((num_accumulators, num_obs))
    for i in range(num_accumulators):
        fpt[i, :] = rng.wald(mu[i], lam[i], size=num_obs)

    resp = fpt.argmin(axis=0)
    rt = fpt.min(axis=0) + t0
    return {"x": np.c_[rt, resp]}
```
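As a quick sanity check, the racing-accumulator function can be run standalone with only NumPy. The copy below mirrors the quoted code, and the parameter values are arbitrary illustrations:

```python
import numpy as np

# Standalone copy of the quoted rdm_experiment_simple, for shape checking only.
def rdm_experiment_simple(v_intercept, v_slope, s_true, s_false, b, t0, num_obs, rng):
    if np.any(np.array((v_intercept, v_slope, s_true, s_false, b, t0)) <= 0):
        raise ValueError("Model parameters must be positive")
    v = np.hstack([v_intercept, v_intercept + v_slope])  # accumulator drift rates
    s = np.hstack([s_false, s_true])                     # accumulator noise scales
    mu, lam = b / v, (b / s) ** 2                        # Wald (inverse-Gaussian) parameters
    fpt = np.stack([rng.wald(mu[i], lam[i], size=num_obs) for i in range(2)])
    resp = fpt.argmin(axis=0)             # index of the fastest accumulator
    rt = fpt.min(axis=0) + t0             # first-passage time plus non-decision time
    return {"x": np.c_[rt, resp]}

out = rdm_experiment_simple(1.0, 0.5, 1.0, 1.0, 1.5, 0.3, num_obs=100,
                            rng=np.random.default_rng(0))
print(out["x"].shape)   # (100, 2): column 0 = RT, column 1 = response
```

This makes the per-dataset shape explicit: `"x"` is `(num_obs, 2)`, which becomes `(batch, num_obs, 2)` once batched, i.e. the shape the adapter's `expand=1` broadcasting below is aimed at.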
> And here is my adapter:

```python
inference_variables = ["x"]
inference_conditions = ["v_intercept", "v_slope", "s_true", "b", "t0"]

adapter = (
    bf.adapters.Adapter()
    .to_array()
    .convert_dtype("float64", "float32")
    .drop("num_obs")
    .concatenate(inference_variables, into="inference_variables")
    .concatenate(inference_conditions, into="inference_conditions")
    .broadcast("inference_conditions", to="inference_variables", expand=1)
)
```
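To see why this fails only at sampling time, here is a toy re-implementation (illustrative only, not BayesFlow's actual code) of a broadcast step keyed on a reference variable: during training the batch contains `inference_variables`, but the `conditions` dictionary passed at sampling time does not, so the lookup fails:

```python
import numpy as np

# Toy stand-in for a Broadcast transform (not BayesFlow's API).
def broadcast_to_key(data, key, to, expand=1):
    target = data[to]                               # KeyError if `to` is absent
    value = np.expand_dims(data[key], axis=expand)  # e.g. (8, 5) -> (8, 1, 5)
    new_shape = target.shape[:expand + 1] + value.shape[expand + 1:]
    return {**data, key: np.broadcast_to(value, new_shape)}

# Training: both keys are present, so conditions get tiled over the obs axis.
train_batch = {
    "inference_variables": np.zeros((8, 100, 2)),   # batch x num_obs x (rt, resp)
    "inference_conditions": np.zeros((8, 5)),       # batch x parameters
}
out = broadcast_to_key(train_batch, "inference_conditions", to="inference_variables")
print(out["inference_conditions"].shape)   # (8, 100, 5)

# Sampling: only conditions are supplied, so the reference key is missing.
try:
    broadcast_to_key({"inference_conditions": np.zeros((8, 5))},
                     "inference_conditions", to="inference_variables")
except KeyError as err:
    print("sampling fails with KeyError:", err)
```

The point is that the failure is not in the broadcast arithmetic itself but in the dictionary lookup of the broadcast target, which only exists in training batches.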
I have checked, and I don't think a `strict` argument would solve this. In other words, the broadcast target `inference_variables` simply does not exist at sampling time, so the transform has nothing to broadcast against.
Thanks for looking into this! I could also modify my prior function (which requires broadcasting) to achieve the same result, but I thought using the broadcast transform would be more elegant. Perhaps it would be good to add an explicit exception for when people use the transform this way.
Yeah, something like this could make sense. Even better, we should throw an error whenever a variable that includes an inferred variable is used as a broadcast target, since it will not be present at inference time when broadcasting is called. It may take a while until we can implement this smart checking, as we have to wait for other features, but I think it would be very nice to have.
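A minimal sketch of such a guard (hypothetical names, not BayesFlow's actual implementation) would raise a descriptive error instead of a bare `KeyError`:

```python
import numpy as np

# Hypothetical guard illustrating the suggested explicit exception.
def checked_broadcast(data, key, to):
    if to not in data:
        raise ValueError(
            f"Cannot broadcast '{key}' to '{to}': '{to}' is not in the data. "
            "Inferred variables are not available during sampling, so broadcast "
            "only to variables that exist at inference time."
        )
    return np.broadcast_to(data[key], data[to].shape)

try:
    checked_broadcast({"inference_conditions": np.zeros((8, 5))},
                      "inference_conditions", to="inference_variables")
except ValueError as err:
    print(err)
```

The smarter check discussed above would go further and reject such a broadcast target at adapter-construction time, before any data flows through.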