diff --git a/edward2/jax/nn/heteroscedastic_lib.py b/edward2/jax/nn/heteroscedastic_lib.py index 1370863d..55dff456 100644 --- a/edward2/jax/nn/heteroscedastic_lib.py +++ b/edward2/jax/nn/heteroscedastic_lib.py @@ -87,7 +87,7 @@ def _get_cov_layer_kernel_init(self): ) # Equivalent to the default kernel init `lecun_normal()` when we set the # scaling factor `scale_layer_kernel_init_factor` to 1. - # https://github.com/google/jax/blob/main/jax/_src/nn/initializers.py#L440 + # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L440 cov_layer_kernel_init = nn.initializers.variance_scaling( self.cov_layer_kernel_init_scale, 'fan_in',