predicted_draws for a brms model does not bring in .chain or .iteration #301

Open
rundel opened this issue Sep 7, 2022 · 8 comments

@rundel

rundel commented Sep 7, 2022

My expectation is that predicted_draws() and related functions should include the .chain and .iteration details from the model, but currently these columns are just NA. The values can be recovered via a join, but that is a bit of a headache and requires an otherwise unnecessary call to gather_draws() or similar.

See the reprex below:

```r
d = data.frame(
  x = rnorm(100),
  y = rnorm(100)
)

b = brms::brm(y~x, data=d, silent=2, refresh=0)

tidybayes::predicted_draws(b, d)
#> # A tibble: 400,000 × 7
#> # Groups:   x, y, .row [100]
#>        x      y  .row .chain .iteration .draw .prediction
#>    <dbl>  <dbl> <int>  <int>      <int> <int>       <dbl>
#>  1 0.132 -0.883     1     NA         NA     1       0.573
#>  2 0.132 -0.883     1     NA         NA     2      -0.293
#>  3 0.132 -0.883     1     NA         NA     3       0.367
#>  4 0.132 -0.883     1     NA         NA     4      -0.924
#>  5 0.132 -0.883     1     NA         NA     5      -0.547
#>  6 0.132 -0.883     1     NA         NA     6       0.411
#>  7 0.132 -0.883     1     NA         NA     7      -0.138
#>  8 0.132 -0.883     1     NA         NA     8      -1.27 
#>  9 0.132 -0.883     1     NA         NA     9       1.20 
#> 10 0.132 -0.883     1     NA         NA    10      -0.361
#> # … with 399,990 more rows
```

The same appears to happen with models from rstanarm.

@mjskay
Owner

mjskay commented Sep 7, 2022

Yes, this is because posterior_predict() for those models does not return chain/iteration information. In some cases this could be retrofitted onto the resulting object, though I'm not sure it can be done in all cases (e.g., with subsampling). Since there wasn't a generic solution that I was sure would work for all parameters that might be passed down to these functions, I opted not to include that information rather than risk it being wrong in some cases. I'm willing to be convinced otherwise if there's a reliable solution.

If you are using the full set of draws (no subsampling), I believe .chain should be floor((.draw - 1) / (n_draws / n_chains)) + 1 and .iteration should be (.draw - 1) %% (n_draws / n_chains) + 1 (since .draw starts at 1, it has to be shifted before dividing), but I would double check that.
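A minimal base R sketch of that arithmetic, assuming the full set of draws (no `ndraws` / `draw_ids`), the same number of saved iterations in every chain, and draws ordered chain by chain (all of chain 1, then all of chain 2, and so on). `draw_to_chain_iteration()` is a hypothetical helper for illustration, not part of tidybayes:

```r
# Hypothetical helper (not part of tidybayes): map a 1-based .draw index to
# .chain and .iteration. Assumes the full set of draws (no ndraws / draw_ids),
# equal numbers of saved iterations per chain, and chain-by-chain ordering.
draw_to_chain_iteration = function(.draw, n_draws, n_chains) {
  n_iter = n_draws %/% n_chains            # saved iterations per chain
  stopifnot(n_iter * n_chains == n_draws)  # chains must have equal length
  zero_based = .draw - 1L                  # .draw starts at 1; shift to 0-based
  data.frame(
    .draw      = .draw,
    .chain     = zero_based %/% n_iter + 1L,
    .iteration = zero_based %% n_iter + 1L
  )
}

# 4 chains x 1000 iterations, as in the reprex (4000 draws)
draw_to_chain_iteration(c(1L, 1000L, 1001L, 4000L), n_draws = 4000L, n_chains = 4L)
#>   .draw .chain .iteration
#> 1     1      1          1
#> 2  1000      1       1000
#> 3  1001      2          1
#> 4  4000      4       1000
```

The boundary cases (the last draw of one chain and the first draw of the next) are where an off-by-one in the index shift would show up.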

@rundel
Author

rundel commented Sep 8, 2022

Thanks for the clarification. I had not considered the issues around ndraws / draw_ids, and this led me to do some digging in brms to wrap my head around what is happening in posterior_predict().

Part of the issue seems to be that posterior_predict() calls prepare_predictions(), which in turn calls posterior::as_draws_matrix(). That initially seems to preserve the draw ids, but they are eventually lost because of how subset_draws() works and how it "repairs" the draw ids: whenever ndraws / draw_ids are used, the original ids are discarded and the resulting draws are renumbered from 1 to n. This behavior of posterior seems a bit bizarre to me, but I'm sure there are reasons for it.

With all of that said, it does seem possible to provide .chain and .iteration when ndraws = NULL and draw_ids = NULL, by matching .draw in the predictions to .draw in as_tibble(as_draws_df(b)) to recover .chain and .iteration.
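A sketch of that join in base R, assuming `ndraws = NULL` and `draw_ids = NULL` so that `.draw` in the predictions refers to the same draw as in the model's draws. `preds` and `draw_meta` are toy stand-ins with made-up values: the first mimics the output of `predicted_draws()`, the second mimics the `.chain` / `.iteration` / `.draw` columns of `posterior::as_draws_df(b)`. `fill_chain_iteration()` is a hypothetical helper:

```r
# Toy stand-in for predicted_draws() output: .chain / .iteration are NA
# (made-up values, for illustration only).
preds = data.frame(
  .row        = c(1L, 1L, 2L, 2L),
  .chain      = NA_integer_,
  .iteration  = NA_integer_,
  .draw       = c(1L, 2L, 1L, 2L),
  .prediction = c(0.5, -0.3, 1.2, 0.1)
)

# Toy stand-in for the draw metadata of posterior::as_draws_df(b):
# 2 chains x 1 saved iteration, so draw 2 is iteration 1 of chain 2.
draw_meta = data.frame(.chain = c(1L, 2L), .iteration = c(1L, 1L), .draw = c(1L, 2L))

# Hypothetical helper: drop the all-NA columns, then look up .chain and
# .iteration by .draw. all.x = TRUE keeps every prediction row; row order
# of the result may differ from preds.
fill_chain_iteration = function(preds, draw_meta) {
  out = merge(
    preds[setdiff(names(preds), c(".chain", ".iteration"))],
    draw_meta[c(".chain", ".iteration", ".draw")],
    by = ".draw", all.x = TRUE, sort = FALSE
  )
  if (anyNA(out$.chain)) warning("some .draw values had no match in draw_meta")
  out[names(preds)]  # restore the original column order
}

fill_chain_iteration(preds, draw_meta)
```

Matching on `.draw` does not depend on any assumption about how draws are ordered across chains; it only requires that `.draw` means the same thing on both sides, which is exactly what subsampling breaks.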

The formulas provided above make sense to me, but they seem potentially fragile if there were ever any weirdness around ordering, compared to just using posterior and the brmsfit object to fill in the blanks.

@rundel
Author

rundel commented Sep 8, 2022

One other quick thought: when draw_ids is supplied, the function(s) will already have the draw ids, in which case they can be used to recover .chain and .iteration.

In the case of ndraws, instead of letting prepare_predictions() handle the conversion of ndraws into draw_ids (see here), this could be done by a similar call in the tidybayes function(s), in which case the above option should again work.
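A minimal sketch of that idea: sample the draw ids up front so the mapping back to the original draws is known, then recover `.chain` / `.iteration` from the ids with chain-by-chain arithmetic (`.draw` shifted to 0-based before dividing). `sample_draw_ids()` is a hypothetical helper. It assumes equal-length chains stored chain by chain, and it assumes (not verified here) that when `draw_ids` is passed through to `posterior_predict()`, the k-th `.draw` in the output corresponds to the k-th id:

```r
# Hypothetical helper: choose the subsample of draw ids up front (instead of
# letting the modeling package turn ndraws into draw_ids internally), so the
# original ids, and hence .chain / .iteration, are never lost.
# Assumes equal-length chains stored chain by chain.
sample_draw_ids = function(ndraws, n_draws, n_chains, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)  # note: changes the global RNG state
  n_iter = n_draws %/% n_chains
  draw_ids = sort(sample.int(n_draws, ndraws))  # original 1-based draw ids
  zero_based = draw_ids - 1L
  data.frame(
    .draw      = seq_along(draw_ids),  # position after subsetting (1..ndraws)
    .draw_id   = draw_ids,             # id in the full set of draws
    .chain     = zero_based %/% n_iter + 1L,
    .iteration = zero_based %% n_iter + 1L
  )
}

info = sample_draw_ids(ndraws = 5L, n_draws = 4000L, n_chains = 4L, seed = 1)
# The ids would then be passed through explicitly (assumption, not run here):
#   tidybayes::predicted_draws(b, d, draw_ids = info$.draw_id)
# and the result joined back to `info` on .draw.
```

Sorting the ids is a defensive choice: if the subsetting step renumbers draws 1 to n in ascending id order, sorted ids still line up with positions, and if it preserves the given order, they line up trivially.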

@rundel rundel changed the title predicted_draws for a brms model does not correctly bring in .chain or .iteration predicted_draws for a brms model does not bring in .chain or .iteration Sep 8, 2022
@mjskay
Owner

mjskay commented Sep 8, 2022

> In the case of ndraws, instead of letting prepare_predictions() handle the conversion of ndraws into draw_ids (see here), this could be done by a similar call in the tidybayes function(s), in which case the above option should again work.

Not a bad idea. This could be a good way to handle random subsets as well, rather than the current method which may be fragile in some cases.

@WillTirone
Copy link

I'm going to attempt a fix for this. I haven't dug in deeply yet, but @mjskay, please let me know if anything fundamental has changed since the past comments. Otherwise I'll proceed with the approach above. Thanks!

@mjskay
Owner

mjskay commented Sep 8, 2023

Sure, would love a fix! I'd probably double-check with @paul-buerkner to see if there's a canonical way to get chain and iteration info out of posterior_predict() and the like.

@paul-buerkner

I think this is related to paul-buerkner/brms#1534. We likely have to wait until brms 3.0 for this feature.

@WillTirone

Makes sense, thank you @paul-buerkner. We have some code that works around this (see below). I assume the preference is to wait for brms 3.0 rather than add a temporary fix to tidybayes?

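```r
# Workaround: refill .chain / .iteration by joining on .draw.
# Only valid with the full set of draws (ndraws = NULL, draw_ids = NULL);
# with subsampling, .draw is renumbered 1..n and the join would be wrong.
fix_draws = function(object, newdata, ..., func = tidybayes::predicted_draws) {
  draws = func(object, newdata, ...)

  n = names(draws)  # remember the original column order

  # left_join keeps every prediction row without adding rows for unused draws
  dplyr::left_join(
    draws |> dplyr::select(-.chain, -.iteration),  # drop the all-NA columns
    tidybayes::tidy_draws(object) |>
      dplyr::select(.chain, .iteration, .draw),    # draw metadata from the fit
    by = ".draw"
  ) |>
    dplyr::select(dplyr::all_of(n)) |>
    dplyr::ungroup()
}
```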


4 participants