Data adapter: Broadcasting causes issue with sampling #258

Open
maltelueken opened this issue Nov 22, 2024 · 6 comments
Labels
feature (New feature or request), sugar (Syntactical sugar or quality of life improvements)

Comments

@maltelueken

When using an Adapter with a Broadcast transform where to="inference_variables", training runs normally, but when I try sampling, it complains that it cannot find the key "inference_variables", which, following the example notebooks, I did not include in the conditions dictionary.
My guess is that the Broadcast transform also needs a strict argument in its forward method for sampling, as in the FilterTransform class?

@paul-buerkner
Contributor

paul-buerkner commented Nov 22, 2024 via email

@maltelueken
Author

maltelueken commented Nov 22, 2024

Sure! Here is my simulator:

import numpy as np
import bayesflow as bf
from typing import Callable
from bayesflow.types import Shape  # import paths assumed for BayesFlow 2.x
from bayesflow.utils import batched_call, tree_stack


class CustomSimulator(bf.simulators.Simulator):
    def __init__(self, prior_fun: Callable, design_fun: Callable, simulator_fun: Callable):
        self.prior_fun = prior_fun
        self.design_fun = design_fun
        self.simulator_fun = simulator_fun

    def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
        prior_dict = self.prior_fun(batch_shape)

        design_dict = self.design_fun(batch_shape)

        design_dict.update(**kwargs)

        sims_dict = self.simulator_fun(batch_shape, **prior_dict, **design_dict)

        data = prior_dict | design_dict | sims_dict

        data = {
            key: np.expand_dims(value, axis=-1) if np.ndim(value) == 1 else value for key, value in data.items()
        }

        return data
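(As a side note for readers: the dictionary comprehension at the end of sample just gives every 1-D array a trailing feature axis and leaves higher-dimensional arrays untouched. A minimal standalone sketch with made-up shapes:)

```python
import numpy as np

# toy batch: a scalar-per-simulation parameter and an already-3D data array
data = {"t0": np.ones(4), "x": np.zeros((4, 100, 2))}

# same normalization as in CustomSimulator.sample: 1-D arrays get a
# trailing feature axis, everything else passes through unchanged
data = {
    key: np.expand_dims(value, axis=-1) if np.ndim(value) == 1 else value
    for key, value in data.items()
}

print(data["t0"].shape)  # (4, 1)
print(data["x"].shape)   # (4, 100, 2)
```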

def batch_simulator(batch_shape, simulator_fun, **kwargs):
    data = batched_call(simulator_fun, batch_shape, kwargs=kwargs, flatten=True)
    data = tree_stack(data, axis=0, numpy=True)
    return data


def rdm_experiment_simple(
    v_intercept,
    v_slope,
    s_true,
    s_false,
    b,
    t0,
    num_obs,
    rng
):
    if np.any(np.array((v_intercept, v_slope, s_true, s_false, b, t0)) <= 0):
        raise ValueError("Model parameters must be positive")

    num_accumulators = 2

    # Acc1 = false, Acc2 = true
    v = np.hstack([v_intercept, v_intercept + v_slope])
    s = np.hstack([s_false, s_true])

    mu = b / v
    lam = (b / s) ** 2

    # First passage time
    fpt = np.zeros((num_accumulators, num_obs))
    
    for i in range(num_accumulators):
        fpt[i, :] = rng.wald(mu[i], lam[i], size=num_obs)

    resp = fpt.argmin(axis=0)
    rt = fpt.min(axis=0) + t0

    return {"x": np.c_[rt, resp]}

And here is my adapter:

inference_variables = ["x"]
inference_conditions = ["v_intercept", "v_slope", "s_true", "b", "t0"]

adapter = (bf.adapters.Adapter()
    .to_array()
    .convert_dtype("float64", "float32")
    .drop("num_obs")
    .concatenate(inference_variables, into="inference_variables")
    .concatenate(inference_conditions, into="inference_conditions")
    .broadcast("inference_conditions", to="inference_variables", expand=1)
)
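(To illustrate the failure mode with a toy reimplementation of the lookup, not BayesFlow's actual Broadcast code: during training the batch dictionary contains "inference_variables", so the reference shape can be resolved, but during sampling only the conditions are passed and the lookup fails:)

```python
import numpy as np

def broadcast_like(data, key, to, expand_axis=1):
    # toy version of a Broadcast transform: the target shape comes from
    # looking up the reference array `to` in the batch dict
    ref = data[to]  # raises KeyError if the reference key is absent
    arr = np.expand_dims(data[key], axis=expand_axis)
    target = ref.shape[: expand_axis + 1] + arr.shape[expand_axis + 1 :]
    return np.broadcast_to(arr, target)

train_batch = {
    "inference_variables": np.zeros((4, 100, 2)),
    "inference_conditions": np.zeros((4, 5)),
}
out = broadcast_like(train_batch, "inference_conditions", "inference_variables")
print(out.shape)  # (4, 100, 5)

# at sampling time only the conditions exist, so the same call fails
sample_batch = {"inference_conditions": np.zeros((4, 5))}
try:
    broadcast_like(sample_batch, "inference_conditions", "inference_variables")
except KeyError as err:
    print("cannot find key:", err)
```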

@paul-buerkner
Contributor

paul-buerkner commented Nov 22, 2024 via email

@paul-buerkner
Contributor

paul-buerkner commented Nov 22, 2024

I have checked, and I don't think a strict argument would fix it. The issue is that without the inference variables (which are only inferred during approximator.sample), we cannot broadcast properly, which we need for the forward pass. So we basically create a chicken-and-egg problem.

In other words, the broadcast to argument should not receive inference variables if possible. I will add this to the documentation. For your example, what would be the easiest way for you to broadcast as you intended without using to="inference_variables"?

@maltelueken
Author

maltelueken commented Nov 24, 2024

Thanks for looking into this! I could also modify my prior function (which requires broadcasting) to achieve the same result, but I thought using the broadcast transform would be more elegant. Perhaps it would be good to raise an explicit exception when people use to="inference_variables"? The error message that currently appears is not very informative about what the actual issue is. Also, you only run into this problem after training the approximator, so a fail-fast solution would be a good idea.
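(A fail-fast check along these lines could live in the transform's constructor; this is only a hypothetical sketch of the suggestion, not actual BayesFlow code:)

```python
def check_broadcast_target(to: str) -> None:
    # hypothetical validation: inference variables are absent from the
    # conditions dict at sampling time, so they cannot serve as a
    # broadcast reference
    if to == "inference_variables":
        raise ValueError(
            "broadcast(to='inference_variables') is not supported: "
            "'inference_variables' is not available during sampling"
        )

check_broadcast_target("summary_variables")  # fine, no error
try:
    check_broadcast_target("inference_variables")
except ValueError as err:
    print(err)
```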

@paul-buerkner
Contributor

paul-buerkner commented Nov 24, 2024

Yeah, something like this could make sense. Even better, we would throw an error whenever a variable used for broadcasting includes an inferred variable, since that variable will not be present at inference time when broadcasting is called. It may take a while until we can implement this smart checking, as we have to wait for other features, but I think it would be very nice to have indeed.

@paul-buerkner paul-buerkner added the labels feature (New feature or request) and sugar (Syntactical sugar or quality of life improvements) on Nov 24, 2024