Skip to content

Commit

Permalink
add authors
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Aug 28, 2024
1 parent 23c48a7 commit 0f34d9e
Showing 1 changed file with 47 additions and 33 deletions.
80 changes: 47 additions & 33 deletions docs/source/notebooks/Example_4_Iowa_Gambling_Task_Short.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"metadata": {},
"source": [
"(example_4)=\n",
"# Example 4: Iowa-Gambling Task"
"# Example 4: Inferring optimistic bias in the Iowa-Gambling Task using variable autoconnection strength"
]
},
{
Expand All @@ -15,6 +15,17 @@
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ilabcode/pyhgf/blob/master/docs/source/notebooks/Example_4_Iowa_Gambling_Task.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"authors:\n",
" - Aleksandrs Baskakovs, Aarhus University, Denmark ([email protected])\n",
" - Nicolas Legrand, Aarhus University, Denmark ([email protected])\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand All @@ -33,12 +44,11 @@
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"from jax import grad\n",
"import numpy as np\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"import seaborn as sns\n",
"from jax import jit\n",
"from jax import grad, jit\n",
"from jax.nn import softmax\n",
"from jax.tree_util import Partial\n",
"from pyhgf.math import binary_surprise\n",
Expand Down Expand Up @@ -693,22 +703,14 @@
"):\n",
"\n",
" # update the autoconnection strengths at the first level\n",
" network.attributes[4][\n",
" \"autoconnection_strength\"\n",
" ] = autoconnection_strength_1\n",
" network.attributes[5][\n",
" \"autoconnection_strength\"\n",
" ] = autoconnection_strength_2\n",
" network.attributes[6][\n",
" \"autoconnection_strength\"\n",
" ] = autoconnection_strength_1\n",
" network.attributes[7][\n",
" \"autoconnection_strength\"\n",
" ] = autoconnection_strength_2\n",
" \n",
" network.attributes[4][\"autoconnection_strength\"] = autoconnection_strength_1\n",
" network.attributes[5][\"autoconnection_strength\"] = autoconnection_strength_2\n",
" network.attributes[6][\"autoconnection_strength\"] = autoconnection_strength_1\n",
" network.attributes[7][\"autoconnection_strength\"] = autoconnection_strength_2\n",
"\n",
" # run the model forward\n",
" network.input_data(input_data=u, observed=observed)\n",
" \n",
"\n",
" # compute decision probabilities given the belief trajectories\n",
" expected_means = jnp.array(\n",
" [\n",
Expand All @@ -722,24 +724,24 @@
" for i in range(4, 8)\n",
" ]\n",
" )\n",
" \n",
"\n",
" # Compute the decision probabilities\n",
" x = beta_1 * expected_means + beta_2 * expected_variances\n",
" x -= jnp.max(x, axis=0)\n",
" decision_probabilities = softmax(x, axis=1)\n",
" \n",
"\n",
" # compute the binary surprise over each bandit x trials\n",
" surprises = binary_surprise(x=decisions.T, expected_mean=decision_probabilities)\n",
" \n",
"\n",
" # avoid numerical overflow\n",
" surprises = jnp.where(surprises > 1e6, 1e6, surprises)\n",
" \n",
"\n",
" # sum all the binary surprises\n",
" surprise = surprises.sum()\n",
" \n",
"\n",
" # returns inf if the model cannot fit somewhere\n",
" surprise = jnp.where(jnp.isnan(surprise), jnp.inf, surprise)\n",
" \n",
"\n",
" return -surprise"
]
},
Expand Down Expand Up @@ -809,7 +811,9 @@
"\n",
" def grad(self, inputs, output_gradients):\n",
" # Create a PyTensor expression of the gradient\n",
" grad_autoconnection_strength_1, grad_autoconnection_strength_2 = grad_custom_op(*inputs)\n",
" grad_autoconnection_strength_1, grad_autoconnection_strength_2 = grad_custom_op(\n",
" *inputs\n",
" )\n",
"\n",
" output_gradient = output_gradients[0]\n",
" # We reference the VJP Op created below, which encapsulates\n",
Expand All @@ -819,22 +823,30 @@
" output_gradient * grad_autoconnection_strength_2,\n",
" ]\n",
"\n",
"\n",
"class GradCustomOp(Op):\n",
" def make_node(self, autoconnection_strength_1, autoconnection_strength_2):\n",
" # Make sure the two inputs are tensor variables\n",
" inputs = [\n",
" pt.as_tensor_variable(autoconnection_strength_1), \n",
" pt.as_tensor_variable(autoconnection_strength_2), \n",
" pt.as_tensor_variable(autoconnection_strength_1),\n",
" pt.as_tensor_variable(autoconnection_strength_2),\n",
" ]\n",
" # Output has the shape type and shape as the first input\n",
" outputs = [inp.type() for inp in inputs]\n",
" return Apply(self, inputs, outputs)\n",
"\n",
" def perform(self, node, inputs, outputs):\n",
" grad_autoconnection_strength_1, grad_autoconnection_strength_2 = grad_logp_fn(*inputs)\n",
" grad_autoconnection_strength_1, grad_autoconnection_strength_2 = grad_logp_fn(\n",
" *inputs\n",
" )\n",
"\n",
" outputs[0][0] = np.asarray(\n",
" grad_autoconnection_strength_1, dtype=node.outputs[0].dtype\n",
" )\n",
" outputs[1][0] = np.asarray(\n",
" grad_autoconnection_strength_2, dtype=node.outputs[1].dtype\n",
" )\n",
"\n",
" outputs[0][0] = np.asarray(grad_autoconnection_strength_1, dtype=node.outputs[0].dtype)\n",
" outputs[1][0] = np.asarray(grad_autoconnection_strength_2, dtype=node.outputs[1].dtype)\n",
"\n",
"# Instantiate the Ops\n",
"custom_op = CustomOp()\n",
Expand Down Expand Up @@ -950,11 +962,13 @@
"source": [
"with pm.Model() as model:\n",
" autoconnection_strength = pm.Beta(\"autoconnection_strength\", 1.0, 1.0, shape=2)\n",
" pm.Potential(\"hgf\", custom_op(\n",
" autoconnection_strength_1=autoconnection_strength[0], \n",
" autoconnection_strength_2=autoconnection_strength[1]\n",
" pm.Potential(\n",
" \"hgf\",\n",
" custom_op(\n",
" autoconnection_strength_1=autoconnection_strength[0],\n",
" autoconnection_strength_2=autoconnection_strength[1],\n",
" ),\n",
" )\n",
" )\n",
" idata = pm.sample(chains=2, cores=1)"
]
},
Expand Down

0 comments on commit 0f34d9e

Please sign in to comment.