Skip to content

Commit

Permalink
Merge pull request #112 from jeremiahpslewis/master
Browse files Browse the repository at this point in the history
Update syntax / function calls for logistic reg
  • Loading branch information
cpfiffer authored Mar 16, 2021
2 parents 7402f77 + e8f8133 commit bd10a05
Showing 1 changed file with 11 additions and 30 deletions.
41 changes: 11 additions & 30 deletions 2_LogisticRegression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: [Turing]: progress logging is disabled globally\n",
"└ @ Turing /home/cameron/.julia/packages/Turing/cReBm/src/Turing.jl:22\n"
]
},
{
"data": {
"text/plain": [
"false"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Import Turing and Distributions.\n",
"using Turing, Distributions\n",
Expand All @@ -57,7 +38,7 @@
"Random.seed!(0);\n",
"\n",
"# Turn off progress monitor.\n",
"Turing.turnprogress(false)"
"Turing.setprogress!(false)"
]
},
{
Expand Down Expand Up @@ -170,8 +151,8 @@
],
"source": [
"# Convert \"Default\" and \"Student\" to numeric values.\n",
"data[!,:DefaultNum] = [r.Default == \"Yes\" ? 1.0 : 0.0 for r in eachrow(data)]\n",
"data[!,:StudentNum] = [r.Student == \"Yes\" ? 1.0 : 0.0 for r in eachrow(data)]\n",
"data[!,:DefaultNum] = (data[!, :Default] .== \"Yes\") .* 1.0\n",
"data[!,:StudentNum] = (data[!, :Student] .== \"Yes\") .* 1.0\n",
"\n",
"# Delete the old columns which say \"Yes\" and \"No\".\n",
"select!(data, Not([:Default, :Student]))\n",
Expand Down Expand Up @@ -19190,10 +19171,10 @@
"source": [
"function prediction(x::Matrix, chain, threshold)\n",
" # Pull the means from each parameter's sampled values in the chain.\n",
" intercept = mean(chain[:intercept].value)\n",
" student = mean(chain[:student].value)\n",
" balance = mean(chain[:balance].value)\n",
" income = mean(chain[:income].value)\n",
" intercept = mean(Array(chain[:intercept]))\n",
" student = mean(Array(chain[:student]))\n",
" balance = mean(Array(chain[:balance].data))\n",
" income = mean(Array(chain[:income].data))\n",
"\n",
" # Retrieve the number of rows.\n",
" n, _ = size(x)\n",
Expand Down Expand Up @@ -19301,9 +19282,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.4.0",
"display_name": "Julia 1.5.3",
"language": "julia",
"name": "julia-1.4"
"name": "julia-1.5"
},
"language_info": {
"file_extension": ".jl",
Expand Down

0 comments on commit bd10a05

Please sign in to comment.