-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add keyword arg to modelmatrix; define momentmatrix #16
base: main
Are you sure you want to change the base?
Changes from 9 commits
e04ad3f
6f6a160
0625c3c
e514db0
bfe8ac6
af65888
11505ff
de0c6ae
c36352b
14ccc70
9a7b2ab
48333e5
e927f72
ab928fa
2858ba0
06206ad
7fa73cb
c3460ae
93f8742
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,19 +34,25 @@ Return the mean of the response. | |
function meanresponse end | ||
|
||
""" | ||
modelmatrix(model::RegressionModel) | ||
modelmatrix(model::RegressionModel; weighted::Bool=false) | ||
|
||
Return the model matrix (a.k.a. the design matrix). | ||
Return the model matrix (a.k.a. the design matrix) or, if `weighted=true` the weighted | ||
model matrix, i.e. `X' * sqrt.(W)`, where `X` is the model matrix and | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why transpose There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My bad...I will fix it |
||
`W` is the diagonal matrix whose elements are the model weights. | ||
""" | ||
function modelmatrix end | ||
function modelmatrix(model::RegressionModel; weighted::Bool=false) end | ||
gragusa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
""" | ||
crossmodelmatrix(model::RegressionModel) | ||
crossmodelmatrix(model::RegressionModel; weighted::Bool=false) | ||
|
||
Return `X'X` where `X` is the model matrix of `model`. | ||
Return `X'X` where `X` is the model matrix of `model` or, if `weighted=true`, `X'WX`, | ||
gragusa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
where `W` is the diagonal matrix whose elements are the model weights. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we define weights? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could add a link to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How exactly do I add a link to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's in the same package so I think something like |
||
This function will return a pre-computed matrix stored in `model` if possible. | ||
""" | ||
crossmodelmatrix(model::RegressionModel) = (x = modelmatrix(model); Symmetric(x' * x)) | ||
function crossmodelmatrix(model::RegressionModel; weighted::Bool=false) | ||
x = weighted ? modelmatrix(model; weighted=weighted) : modelmatrix(model) | ||
return Symmetric(x' * x) | ||
end | ||
|
||
""" | ||
leverage(model::RegressionModel) | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -6,13 +6,31 @@ using StatsAPI: RegressionModel, crossmodelmatrix | |||||||||||||||||||||||||||||||||||
struct MyRegressionModel <: RegressionModel | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
struct MyWeightedRegressionModel <: RegressionModel | ||||||||||||||||||||||||||||||||||||
wts::AbstractVector | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
StatsAPI.modelmatrix(::MyRegressionModel) = [1 2; 3 4] | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
function StatsAPI.modelmatrix(r::MyWeightedRegressionModel; weighted::Bool=false) | ||||||||||||||||||||||||||||||||||||
X = [1 2; 3 4] | ||||||||||||||||||||||||||||||||||||
weighted ? sqrt.(r.wts).*X : X | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
w = [0.3, 0.2] | ||||||||||||||||||||||||||||||||||||
Comment on lines
+9
to
+20
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that the methods hardcode the matrix, probably not worth having a separate type which doesn't hardcode weights:
Suggested change
... and simplify tests below. |
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@testset "TestRegressionModel" begin | ||||||||||||||||||||||||||||||||||||
m = MyRegressionModel() | ||||||||||||||||||||||||||||||||||||
r = MyWeightedRegressionModel(w) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@test crossmodelmatrix(m) == [10 14; 14 20] | ||||||||||||||||||||||||||||||||||||
@test crossmodelmatrix(m; weighted=false) == [10 14; 14 20] | ||||||||||||||||||||||||||||||||||||
@test crossmodelmatrix(m) isa Symmetric | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@test crossmodelmatrix(r) == [10 14; 14 20] | ||||||||||||||||||||||||||||||||||||
@test crossmodelmatrix(r; weighted=false) == [10 14; 14 20] | ||||||||||||||||||||||||||||||||||||
@test crossmodelmatrix(r; weighted=true) ≈ [2.1 3.0; 3.0 4.4] | ||||||||||||||||||||||||||||||||||||
@test crossmodelmatrix(r; weighted=true) isa Symmetric | ||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
end # module TestRegressionModel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.