Skip to content

Commit

Permalink
Update imports.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Nov 17, 2023
1 parent 80ba6f7 commit 3eb6951
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions notebooks/basic_usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"metadata": {
"id": "1OtVkBf7OAIF"
},
Expand All @@ -68,7 +68,6 @@
"import harmonic as hm\n",
"from functools import partial\n",
"import emcee\n",
"from harmonic import model_nf\n",
"import jax.numpy as jnp"
]
},
Expand All @@ -94,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 5,
"metadata": {
"id": "8PgO6f4VQSpD"
},
Expand Down Expand Up @@ -132,7 +131,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"metadata": {
"id": "us1kBuWlQZTy"
},
Expand Down Expand Up @@ -161,7 +160,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"metadata": {
"id": "Eel3bSORQZW0"
},
Expand Down Expand Up @@ -210,7 +209,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 8,
"metadata": {
"id": "eSxxNW1KQZZc"
},
Expand All @@ -232,7 +231,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 9,
"metadata": {
"id": "efPNgW8qQZcW"
},
Expand All @@ -256,7 +255,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"metadata": {
"id": "WrB47hA3QZfM"
},
Expand All @@ -265,7 +264,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Training NF, current loss: 6.950: 100%|██████████| 20/20 [01:51<00:00, 5.58s/it]\n"
"Training NF, current loss: 6.955: 80%|████████ | 16/20 [01:59<00:28, 7.24s/it]"
]
}
],
Expand All @@ -276,10 +275,10 @@
"n_unscaled_layers = 4\n",
"temperature = 0.9\n",
"\n",
"model = model_nf.RealNVPModel(ndim, standardize=True, temperature=temperature)\n",
"model = hm.model.RealNVPModel(ndim, standardize=True, temperature=temperature)\n",
"epochs_num = 20\n",
"# Train model\n",
"model.fit(jnp.array(chains_train.samples), epochs=epochs_num, verbose= True)"
"model.fit(chains_train.samples, epochs=epochs_num, verbose= True)"
]
},
{
Expand All @@ -295,7 +294,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -320,7 +319,7 @@
"samples = samples.reshape((-1, ndim))\n",
"samp_num = samples.shape[0]\n",
"flow_samples = model.sample(samp_num)\n",
"hm.utils_flows.plot_getdist_compare(samples, flow_samples)"
"hm.utils.plot_getdist_compare(samples, flow_samples)"
]
},
{
Expand All @@ -336,7 +335,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {
"id": "yP3AjUrCQZhh"
},
Expand Down Expand Up @@ -375,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {
"id": "YMQ9m747TAOT"
},
Expand Down Expand Up @@ -415,7 +414,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
Expand Down

0 comments on commit 3eb6951

Please sign in to comment.