From 58148df7890fe0e7af6e4fb1e2f81b3b153911c3 Mon Sep 17 00:00:00 2001 From: Olivier Labayle <48188914+olivierlabayle@users.noreply.github.com> Date: Thu, 4 Apr 2024 18:53:28 +0100 Subject: [PATCH] add support for composed estimands in from_param_file (#190) * add support for composed estimands in from_param_file * up manifest --- Manifest.toml | 170 +++++++++++++---------------- src/tl_inputs/from_actors.jl | 6 +- src/tl_inputs/from_param_files.jl | 81 +++++++++++--- src/tl_inputs/tl_inputs.jl | 68 ++++-------- test/tl_inputs/from_param_files.jl | 67 ++++++++++-- test/tl_inputs/test_utils.jl | 20 +++- test/tl_inputs/tl_inputs.jl | 97 ---------------- 7 files changed, 235 insertions(+), 274 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 4380233..3ffb0f4 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -218,12 +218,6 @@ git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" uuid = "9718e550-a3fa-408a-8086-8db961cd8217" version = "0.1.1" -[[deps.BenchmarkTools]] -deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] -git-tree-sha1 = "f1dff6729bc61f4d49e140da1af55dcd1ac97b2f" -uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -version = "1.5.0" - [[deps.BioCore]] deps = ["Automa", "BufferedStreams", "YAML"] git-tree-sha1 = "476edbf4ef94594fff430a84ca96f86cb2327a71" @@ -304,9 +298,9 @@ version = "1.0.5" [[deps.CairoMakie]] deps = ["CRC32c", "Cairo", "Colors", "FFTW", "FileIO", "FreeType", "GeometryBasics", "LinearAlgebra", "Makie", "PrecompileTools"] -git-tree-sha1 = "6dc1bbdd6a133adf4aa751d12dbc2c6ae59f873d" +git-tree-sha1 = "e64af6bea1f0dcde6ebecc581768074f992ad39b" uuid = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -version = "0.11.9" +version = "0.11.10" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] @@ -335,9 +329,9 @@ weakdeps = ["JSON", "RecipesBase", "SentinelArrays", "StructTypes"] [[deps.CategoricalDistributions]] deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes"] -git-tree-sha1 = "6d4569d555704cdf91b3417c0667769a4a7cbaa2" +git-tree-sha1 = "926862f549a82d6c3a7145bc7f1adff2a91a39f0" uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e" -version = "0.1.14" +version = "0.1.15" [deps.CategoricalDistributions.extensions] UnivariateFiniteDisplayExt = "UnicodePlots" @@ -490,9 +484,9 @@ version = "0.17.6" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] -git-tree-sha1 = "c53fc348ca4d40d7b371e71fd52251839080cbc9" +git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.4" +version = "1.5.5" weakdeps = ["IntervalSets", "StaticArrays"] [deps.ConstructionBase.extensions] @@ -506,9 +500,9 @@ uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" version = "0.1.3" [[deps.Contour]] -git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" +git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" -version = "0.6.2" +version = "0.6.3" [[deps.CpuId]] deps = ["Markdown"] @@ -654,9 +648,9 @@ version = "1.0.4" [[deps.EvoTrees]] deps = ["BSON", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "e1107e45d7fe1a3c5dd335376bb6333b42cf9d1c" +git-tree-sha1 = "92d1f78f95f4794bf29bd972dacfa37ea1fec9f4" uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" -version = "0.16.6" +version = "0.16.7" [deps.EvoTrees.extensions] EvoTreesCUDAExt = "CUDA" @@ -729,9 +723,9 @@ version = "0.1.1" [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "c5c28c245101bd59154f649e19b038d15901b5dc" +git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.2" +version = "1.16.3" [[deps.FilePaths]] deps = ["FilePathsBase", "MacroTools", "Reexport", "Requires"] @@ -749,10 +743,10 @@ version = "0.9.21" uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] -deps = ["LinearAlgebra", "Random"] -git-tree-sha1 = "5b93957f6dcd33fc343044af3d48c215be2562f1" +deps = ["LinearAlgebra"] +git-tree-sha1 = "bfe82a708416cf00b73a3198db0859c82f741558" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.9.3" +version = "1.10.0" weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] @@ -789,9 +783,9 @@ uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" version = "2.13.93+0" [[deps.Format]] -git-tree-sha1 = "f3cf88025f6d03c194d73f5d13fee9004a108329" +git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc" uuid = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" -version = "1.3.6" +version = "1.3.7" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] @@ -910,9 +904,9 @@ version = "1.9.0" [[deps.GridLayoutBase]] deps = ["GeometryBasics", "InteractiveUtils", "Observables"] -git-tree-sha1 = "af13a277efd8a6e716d79ef635d5342ccb75be61" +git-tree-sha1 = "6f93a83ca11346771a93bbde2bdad2f65b61498f" uuid = "3955a311-db13-416c-9275-1d80ed98e5e9" -version = "0.10.0" +version = "0.10.2" [[deps.Grisu]] git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" @@ -933,15 +927,15 @@ version = "0.17.1" [[deps.HDF5_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] -git-tree-sha1 = "e4591176488495bf44d7456bd73179d87d5e6eab" +git-tree-sha1 = "6384c847ff5056c5624e30e75b3ca48902cae0ac" uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.14.3+1" +version = "1.14.3+2" [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "db864f2d91f68a5912937af80327d288ea1f3aee" +git-tree-sha1 = "8e59b47b9dc525b70550ca082ce85bcd7f5477cd" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.3" +version = "1.10.5" [[deps.HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] @@ -1075,9 +1069,9 @@ version = "0.15.1" [[deps.IntervalArithmetic]] deps = ["CRlibm_jll", "RoundingEmulator"] -git-tree-sha1 = "2d6d22fe481eff6e337808cc0880c567d7324f9a" +git-tree-sha1 = "552505ed27d2a90ff04c15b0ecf4634e0ab5547b" uuid = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" -version = "0.22.8" +version = "0.22.9" weakdeps = ["DiffRules", "ForwardDiff", "RecipesBase"] [deps.IntervalArithmetic.extensions] @@ -1098,9 +1092,13 @@ weakdeps = ["Random", "RecipesBase", "Statistics"] [[deps.InverseFunctions]] deps = ["Test"] -git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46" +git-tree-sha1 = "896385798a8d49a255c398bd49162062e4a4c435" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.12" +version = "0.1.13" +weakdeps = ["Dates"] + + [deps.InverseFunctions.extensions] + DatesExt = "Dates" [[deps.InvertedIndices]] git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" @@ -1212,9 +1210,9 @@ version = "3.100.1+0" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "ddab4d40513bce53c8e3157825e245224f74fae7" +git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.6.0" +version = "6.6.3" [deps.LLVM.extensions] BFloat16sExt = "BFloat16s" @@ -1396,9 +1394,9 @@ version = "1.0.3" [[deps.LoopVectorization]] deps = ["ArrayInterface", "CPUSummary", "CloseOpenIntervals", "DocStringExtensions", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "PrecompileTools", "SIMDTypes", "SLEEFPirates", "Static", "StaticArrayInterface", "ThreadingUtilities", "UnPack", "VectorizationBase"] -git-tree-sha1 = "0f5648fbae0d015e3abe5867bca2b362f67a5894" +git-tree-sha1 = "a13f3be5d84b9c95465d743c82af0b094ef9c2e2" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" -version = "0.12.166" +version = "0.12.169" weakdeps = ["ChainRulesCore", "ForwardDiff", "SpecialFunctions"] [deps.LoopVectorization.extensions] @@ -1443,9 +1441,9 @@ version = "0.1.4" [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "f4782ed751d4683a2858278ef2997130a82ca710" +git-tree-sha1 = "17d160e8f796ab5ceb4c017bc4019d21fd686a35" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "1.1.2" +version = "1.2.1" weakdeps = ["StatisticalMeasures"] [deps.MLJBase.extensions] @@ -1483,9 +1481,9 @@ version = "0.10.0" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "14bd8088cf7cd1676aa83a57004f8d23d43cd81e" +git-tree-sha1 = "d2a45e1b5998ba3fdfb6cfe0c81096d4c7fb40e7" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.9.5" +version = "1.9.6" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] @@ -1495,9 +1493,9 @@ version = "0.16.16" [[deps.MLJTuning]] deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase", "StatisticalMeasuresBase"] -git-tree-sha1 = "42949f1a85c48f390cfce46cb1c4fcda1846b204" +git-tree-sha1 = "4a2c14b9529753db3ece53fd635c609220200507" uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" -version = "0.8.2" +version = "0.8.4" [[deps.MLJXGBoostInterface]] deps = ["MLJModelInterface", "SparseArrays", "Tables", "XGBoost"] @@ -1541,10 +1539,10 @@ uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.13" [[deps.Makie]] -deps = ["Animations", "Base64", "CRC32c", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG_jll", "FileIO", "FilePaths", "FixedPointNumbers", "Format", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Scratch", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "StableHashTraits", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"] -git-tree-sha1 = "27af6be179c711fb916a597b6644fbb5b80becc0" +deps = ["Animations", "Base64", "CRC32c", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG_jll", "FileIO", "FilePaths", "FixedPointNumbers", "Format", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Scratch", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"] +git-tree-sha1 = "46ca613be7a1358fb93529726ea2fc28050d3ae0" uuid = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" -version = "0.20.8" +version = "0.20.9" [[deps.MakieCore]] deps = ["Observables", "REPL"] @@ -1566,12 +1564,6 @@ version = "0.4.2" deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[deps.MathOptInterface]] -deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "DataStructures", "ForwardDiff", "JSON", "LinearAlgebra", "MutableArithmetics", "NaNMath", "OrderedCollections", "PrecompileTools", "Printf", "SparseArrays", "SpecialFunctions", "Test", "Unicode"] -git-tree-sha1 = "679c1aec6934d322783bd15db4d18f898653be4f" -uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" -version = "1.27.0" - [[deps.MathTeXEngine]] deps = ["AbstractTrees", "Automa", "DataStructures", "FreeTypeAbstraction", "GeometryBasics", "LaTeXStrings", "REPL", "RelocatableFolders", "Test", "UnicodeFun"] git-tree-sha1 = "8f52dbaa1351ce4cb847d95568cb29e62a307d93" @@ -1648,12 +1640,6 @@ git-tree-sha1 = "8d852646862c96e226367ad10c8af56099b4047e" uuid = "3b2b4ff1-bcff-5658-a3ee-dbcf1ce5ac09" version = "0.4.4" -[[deps.MutableArithmetics]] -deps = ["LinearAlgebra", "SparseArrays", "Test"] -git-tree-sha1 = "302fd161eb1c439e4115b51ae456da4e9984f130" -uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" -version = "1.4.1" - [[deps.NLSolversBase]] deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" @@ -1662,9 +1648,9 @@ version = "7.8.3" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "877f15c331337d54cf24c797d5bcb2e48ce21221" +git-tree-sha1 = "1fa1a14766c60e66ab22e242d45c1857c83a3805" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.12" +version = "0.9.13" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -1772,9 +1758,9 @@ version = "1.4.2" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "60e3045590bd104a16fefb12836c00c0ef8c7f8c" +git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.13+0" +version = "3.0.13+1" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1783,10 +1769,16 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" version = "0.5.5+0" [[deps.Optim]] -deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "MathOptInterface", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] -git-tree-sha1 = "d024bfb56144d947d4fafcd9cb5cafbe3410b133" +deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] +git-tree-sha1 = "d9b79c4eed437421ac4285148fcadf42e0700e89" uuid = "429524aa-4258-5aef-a3af-852621145aeb" -version = "1.9.2" +version = "1.9.4" + + [deps.Optim.extensions] + OptimMOIExt = "MathOptInterface" + + [deps.Optim.weakdeps] + MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" [[deps.Opus_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1858,12 +1850,6 @@ git-tree-sha1 = "eb3f9df2457819bf0a9019bd93cc451697a0751e" uuid = "2ae35dd2-176d-5d53-8349-f30d82d94d4f" version = "0.4.20" -[[deps.PikaParser]] -deps = ["DocStringExtensions"] -git-tree-sha1 = "d6ff87de27ff3082131f31a714d25ab6d0a88abf" -uuid = "3bbf5609-3e7b-44cd-8549-7c69f321e792" -version = "0.6.1" - [[deps.Pixman_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] git-tree-sha1 = "64779bc4c9784fee475689a1752ef4d5747c5e87" @@ -1903,7 +1889,6 @@ deps = ["LinearAlgebra", "RecipesBase", "Setfield", "SparseArrays"] git-tree-sha1 = "a9c7a523d5ed375be3983db190f6a5874ae9286d" uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" version = "4.0.6" -weakdeps = ["ChainRulesCore", "FFTW", "MakieCore", "MutableArithmetics"] [deps.Polynomials.extensions] PolynomialsChainRulesCoreExt = "ChainRulesCore" @@ -1911,6 +1896,12 @@ weakdeps = ["ChainRulesCore", "FFTW", "MakieCore", "MutableArithmetics"] PolynomialsMakieCoreExt = "MakieCore" PolynomialsMutableArithmeticsExt = "MutableArithmetics" + [deps.Polynomials.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" + MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" + [[deps.PooledArrays]] deps = ["DataAPI", "Future"] git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" @@ -1961,10 +1952,6 @@ version = "0.5.6" deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -[[deps.Profile]] -deps = ["Printf"] -uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" - [[deps.ProgressMeter]] deps = ["Distributed", "Printf"] git-tree-sha1 = "763a8ceb07833dd51bb9e3bbca372de32c0605ad" @@ -2268,12 +2255,6 @@ git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" -[[deps.StableHashTraits]] -deps = ["Compat", "PikaParser", "SHA", "Tables", "TupleTools"] -git-tree-sha1 = "10dc702932fe05a0e09b8e5955f00794ea1e8b12" -uuid = "c5dd0088-6c3f-4803-b00e-f31a60c170fa" -version = "1.1.8" - [[deps.StableRNGs]] deps = ["Random", "Test"] git-tree-sha1 = "ddc1a7b85e760b5285b50b882fa91e40c603be47" @@ -2358,9 +2339,9 @@ version = "1.7.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "1d77abd07f617c4868c33d4f5b9e1dbb2643c9cf" +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.2" +version = "0.34.3" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] @@ -2501,9 +2482,9 @@ version = "0.6.8" [[deps.TimeZones]] deps = ["Artifacts", "Dates", "Downloads", "InlineStrings", "LazyArtifacts", "Mocking", "Printf", "Scratch", "TZJData", "Unicode", "p7zip_jll"] -git-tree-sha1 = "89e64d61ef3cd9e80f7fc12b7d13db2d75a23c03" +git-tree-sha1 = "cc54d5c9803309474014a8955a96e4adcd11bcf4" uuid = "f269a46b-ccf7-5d73-abea-4c690281aa53" -version = "1.13.0" +version = "1.14.0" weakdeps = ["RecipesBase"] [deps.TimeZones.extensions] @@ -2540,11 +2521,6 @@ git-tree-sha1 = "4d4ed7f294cda19382ff7de4c137d24d16adc89b" uuid = "981d1d27-644d-49a2-9326-4793e63143c3" version = "0.1.0" -[[deps.TupleTools]] -git-tree-sha1 = "41d61b1c545b06279871ef1a4b5fcb2cac2191cd" -uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" -version = "1.5.0" - [[deps.URIs]] git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" @@ -2630,9 +2606,9 @@ version = "2.0.1+0" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "07e470dabc5a6a4254ffebc29a1b3fc01464e105" +git-tree-sha1 = "532e22cf7be8462035d092ff21fada7527e2c488" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.12.5+0" +version = "2.12.6+0" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"] @@ -2642,9 +2618,9 @@ version = "1.1.34+0" [[deps.XZ_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "31c421e5516a6248dfb22c194519e37effbf1f30" +git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632" uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800" -version = "5.6.1+0" +version = "5.4.6+0" [[deps.Xorg_libX11_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] @@ -2707,9 +2683,9 @@ version = "1.2.13+1" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "49ce682769cd5de6c72dcf1b94ed7790cd08974c" +git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.5+0" +version = "1.5.6+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] diff --git a/src/tl_inputs/from_actors.jl b/src/tl_inputs/from_actors.jl index 2fc8315..0d61940 100644 --- a/src/tl_inputs/from_actors.jl +++ b/src/tl_inputs/from_actors.jl @@ -124,7 +124,7 @@ function control_case_settings(::Type{TMLE.StatisticalATE}, treatments, data) end function addEstimands!(estimands, treatments, variables, data; positivity_constraint=0.) - freqs = TargeneCore.frequency_table(data, treatments) + freqs = TMLE.frequency_table(data, treatments) # This loop adds all ATE estimands where all other treatments than # the bQTL are fixed, at the order 1, this is the simple bQTL's ATE for setting in control_case_settings(TMLE.StatisticalATE, treatments, data) @@ -134,7 +134,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]), outcome_extra_covariates = variables.covariates ) - if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint) + if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint) update_estimands_from_outcomes!(estimands, Ψ, variables.targets) end end @@ -147,7 +147,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]), outcome_extra_covariates = variables.covariates ) - if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint) + if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint) update_estimands_from_outcomes!(estimands, Ψ, variables.targets) end end diff --git a/src/tl_inputs/from_param_files.jl b/src/tl_inputs/from_param_files.jl index c70ec29..2f3a79a 100644 --- a/src/tl_inputs/from_param_files.jl +++ b/src/tl_inputs/from_param_files.jl @@ -16,6 +16,7 @@ MismatchedCaseControlEncodingError() = NoRemainingParamsError(positivity_constraint) = ArgumentError(string("No parameter passed the given positivity constraint: ", positivity_constraint)) +MismatchedVariableError(variable) = ArgumentError(string("Each component of a ComposedEstimand should contain the same ", variable, " variables.")) function check_genotypes_encoding(val::NamedTuple, type) if !(typeof(val.case) <: type && typeof(val.control) <: type) @@ -27,17 +28,66 @@ check_genotypes_encoding(val::T, type) where T = T <: type || throw(MismatchedCaseControlEncodingError()) +get_treatments(Ψ) = keys(Ψ.treatment_values) + +function get_treatments(Ψ::ComposedEstimand) + treatments = get_treatments(first(Ψ.args)) + if length(Ψ.args) > 1 + for arg in Ψ.args[2:end] + get_treatments(arg) == treatments || throw(MismatchedVariableError("treatments")) + end + end + return treatments +end + +get_confounders(Ψ) = Tuple(Iterators.flatten((Tconf for Tconf ∈ Ψ.treatment_confounders))) + +function get_confounders(Ψ::ComposedEstimand) + confounders = get_confounders(first(Ψ.args)) + if length(Ψ.args) > 1 + for arg in Ψ.args[2:end] + get_confounders(arg) == confounders || throw(MismatchedVariableError("confounders")) + end + end + return confounders +end + +get_outcome_extra_covariates(Ψ) = Ψ.outcome_extra_covariates + +function get_outcome_extra_covariates(Ψ::ComposedEstimand) + outcome_extra_covariates = get_outcome_extra_covariates(first(Ψ.args)) + if length(Ψ.args) > 1 + for arg in Ψ.args[2:end] + get_outcome_extra_covariates(arg) == outcome_extra_covariates || throw(MismatchedVariableError("outcome extra covariates")) + end + end + return outcome_extra_covariates +end + +get_outcome(Ψ) = Ψ.outcome + +function get_outcome(Ψ::ComposedEstimand) + outcome = get_outcome(first(Ψ.args)) + if length(Ψ.args) > 1 + for arg in Ψ.args[2:end] + get_outcome(arg) == outcome || throw(MismatchedVariableError("outcome")) + end + end + return outcome +end + function get_variables(estimands, traits, pcs) genetic_variants = Set{Symbol}() others = Set{Symbol}() pcs = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(pcs))) alltraits = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(traits))) for Ψ in estimands - treatments = keys(Ψ.treatment_values) - confounders = Iterators.flatten((Tconf for Tconf ∈ Ψ.treatment_confounders)) + treatments = get_treatments(Ψ) + confounders = get_confounders(Ψ) + outcome_extra_covariates = get_outcome_extra_covariates(Ψ) push!( others, - Ψ.outcome_extra_covariates..., + outcome_extra_covariates..., confounders..., treatments... ) @@ -123,6 +173,8 @@ function adjust_parameter_sections(Ψ::T, variants_alleles, pcs) where T<:TMLE.E return T(outcome=Ψ.outcome, treatment_values=treatments, treatment_confounders=confounders, outcome_extra_covariates=Ψ.outcome_extra_covariates) end +adjust_parameter_sections(Ψ::ComposedEstimand, variants_alleles, pcs) = + ComposedEstimand(Ψ.f, Tuple(adjust_parameter_sections(arg, variants_alleles, pcs) for arg in Ψ.args)) function append_from_valid_estimands!( estimands::Vector{<:TMLE.Estimand}, @@ -136,29 +188,28 @@ function append_from_valid_estimands!( # Update treatment's and confounders's sections of Ψ Ψ = adjust_parameter_sections(Ψ, variants_alleles, variables.pcs) # Update frequency tables with current treatments - treatments = sorted_treatment_names(Ψ) + treatments = get_treatments(Ψ) if !haskey(frequency_tables, treatments) - frequency_tables[treatments] = TargeneCore.frequency_table(data, collect(treatments)) + frequency_tables[treatments] = TMLE.frequency_table(data, treatments) end # Check if parameter satisfies positivity - satisfies_positivity(Ψ, frequency_tables[treatments]; - positivity_constraint=positivity_constraint) || return - # Expand wildcard to all outcomes - if Ψ.outcome === :ALL - update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes) - else - # Ψ.target || MissingVariableError(variable) - push!(estimands, Ψ) + if TMLE.satisfies_positivity(Ψ, frequency_tables[treatments]; positivity_constraint=positivity_constraint) + # Expand wildcard to all outcomes + if get_outcome(Ψ) === :ALL + update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes) + else + push!(estimands, Ψ) + end end end function adjusted_estimands(estimands, variables, data; positivity_constraint=0.) final_estimands = TMLE.Estimand[] variants_alleles = Dict(v => Set(unique(skipmissing(data[!, v]))) for v in variables.genetic_variants) - freqency_tables = Dict() + frequency_tables = Dict() for Ψ in estimands # If the genotypes encoding is a string representation make sure they match the actual genotypes - append_from_valid_estimands!(final_estimands, freqency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint) + append_from_valid_estimands!(final_estimands, frequency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint) end length(final_estimands) > 0 || throw(NoRemainingParamsError(positivity_constraint)) diff --git a/src/tl_inputs/tl_inputs.jl b/src/tl_inputs/tl_inputs.jl index 2ede39b..11119f9 100644 --- a/src/tl_inputs/tl_inputs.jl +++ b/src/tl_inputs/tl_inputs.jl @@ -64,6 +64,7 @@ NotAllVariantsFoundError(rsids) = ArgumentError(string("Some variants were not found in the genotype files: ", join(rsids, ", "))) NotBiAllelicOrUnphasedVariantError(rsid) = ArgumentError(string("Variant: ", rsid, " is not bi-allelic or not unphased.")) + """ bgen_files(snps, bgen_prefix) @@ -103,47 +104,8 @@ function call_genotypes(bgen_prefix::String, query_rsids::Set{<:AbstractString}, return genotypes end -sorted_treatment_names(Ψ) = tuple(sort(collect(keys(Ψ.treatment_values)))...) - -function setting_iterator(Ψ::TMLE.StatisticalIATE) - treatments = sorted_treatment_names(Ψ) - return ( - NamedTuple{treatments}(collect(Tval)) for - Tval in Iterators.product((values(Ψ.treatment_values[T]) for T in treatments)...) - ) -end - -function setting_iterator(Ψ::TMLE.StatisticalATE) - treatments = sorted_treatment_names(Ψ) - return ( - NamedTuple{treatments}([(Ψ.treatment_values[T][c]) for T in treatments]) - for c in (:case, :control) - ) -end - -function setting_iterator(Ψ::TMLE.StatisticalCM) - treatments = sorted_treatment_names(Ψ) - return (NamedTuple{treatments}(Ψ.treatment_values[T] for T in treatments), ) -end - -function satisfies_positivity(Ψ::TMLE.Estimand, freqs; positivity_constraint=0.01) - for base_setting in setting_iterator(Ψ) - if !haskey(freqs, base_setting) || freqs[base_setting] < positivity_constraint - return false - end - end - return true -end - -function frequency_table(data, treatments::AbstractVector) - treatments = sort(treatments) - freqs = Dict() - N = nrow(data) - for (key, group) in pairs(groupby(data, treatments; skipmissing=true)) - freqs[NamedTuple(key)] = nrow(group) / N - end - return freqs -end +TMLE.satisfies_positivity(Ψ::ComposedEstimand, freqs; positivity_constraint=0.01) = + all(TMLE.satisfies_positivity(arg, freqs; positivity_constraint=positivity_constraint) for arg in Ψ.args) read_txt_file(path::Nothing) = nothing read_txt_file(path) = CSV.read(path, DataFrame, header=false)[!, 1] @@ -164,15 +126,27 @@ function merge(traits, pcs, genotypes) ) end +estimand_with_new_outcome(Ψ::T, outcome) where T = T( + outcome=outcome, + treatment_values=Ψ.treatment_values, + treatment_confounders=Ψ.treatment_confounders, + outcome_extra_covariates=Ψ.outcome_extra_covariates +) + function update_estimands_from_outcomes!(estimands, Ψ::T, outcomes) where T for outcome in outcomes push!( - estimands, - T( - outcome=outcome, - treatment_values=Ψ.treatment_values, - treatment_confounders=Ψ.treatment_confounders, - outcome_extra_covariates=Ψ.outcome_extra_covariates) + estimands, + estimand_with_new_outcome(Ψ, outcome) + ) + end +end + +function update_estimands_from_outcomes!(estimands, Ψ::ComposedEstimand, outcomes) + for outcome in outcomes + push!( + estimands, + ComposedEstimand(Ψ.f, Tuple(estimand_with_new_outcome(arg, outcome) for arg in Ψ.args)) ) end end diff --git a/test/tl_inputs/from_param_files.jl b/test/tl_inputs/from_param_files.jl index 2215245..8e9c153 100644 --- a/test/tl_inputs/from_param_files.jl +++ b/test/tl_inputs/from_param_files.jl @@ -24,6 +24,40 @@ include(joinpath(TESTDIR, "tl_inputs", "test_utils.jl")) pcs = TargeneCore.read_csv_file(joinpath(TESTDIR, "data", "pcs.csv")) # extraW, extraT, extraC are parsed from all param_files estimands = make_estimands_configuration().estimands + # get_treatments, get_outcome, ... + ## Simple Estimand + Ψ = estimands[1] + @test TargeneCore.get_outcome(Ψ) == :ALL + @test TargeneCore.get_treatments(Ψ) == keys(Ψ.treatment_values) + @test TargeneCore.get_confounders(Ψ) == () + @test TargeneCore.get_outcome_extra_covariates(Ψ) == () + ## ComposedEstimand + Ψ = estimands[5] + @test TargeneCore.get_outcome(Ψ) == :ALL + @test TargeneCore.get_treatments(Ψ) == keys(Ψ.args[1].treatment_values) + @test TargeneCore.get_confounders(Ψ) == () + @test TargeneCore.get_outcome_extra_covariates(Ψ) == (Symbol("22001"), ) + ## Bad ComposedEstimand + Ψ = ComposedEstimand( + TMLE.joint_estimand, ( + CM( + outcome = "Y1", + treatment_values = (RSID_3 = "GG", RSID_198 = "AG"), + treatment_confounders = (RSID_3 = [], RSID_198 = []), + outcome_extra_covariates = [22001] + ), + CM( + outcome = "Y2", + treatment_values = (RSID_2 = "AA", RSID_198 = "AG"), + treatment_confounders = (RSID_2 = [:PC1], RSID_198 = []), + outcome_extra_covariates = [] + )) + ) + @test_throws ArgumentError TargeneCore.get_outcome(Ψ) == :ALL + @test_throws ArgumentError TargeneCore.get_treatments(Ψ) + @test_throws ArgumentError TargeneCore.get_confounders(Ψ) + @test_throws ArgumentError TargeneCore.get_outcome_extra_covariates(Ψ) + # get_variables variables = TargeneCore.get_variables(estimands, traits, pcs) @test variables.genetic_variants == Set([:RSID_198, :RSID_2]) @test variables.outcomes == Set([:BINARY_1, :CONTINUOUS_2, :CONTINUOUS_1, :BINARY_2]) @@ -38,8 +72,9 @@ end ) pcs = Set([:PC1, :PC2]) variants_alleles = Dict(:RSID_198 => Set(genotypes.RSID_198)) - # AG is not in the genotypes but GA is - Ψ = make_estimands_configuration().estimands[4] + estimands = make_estimands_configuration().estimands + # RS198 AG is not in the genotypes but GA is + Ψ = estimands[4] @test Ψ.treatment_values.RSID_198 == (case="AG", control="AA") new_Ψ = TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs) @test new_Ψ.outcome == Ψ.outcome @@ -50,6 +85,19 @@ end RSID_2 = (case = "AA", control = "GG") ) + # ComnposedEstimand + Ψ = estimands[5] + @test Ψ.args[1].treatment_values == (RSID_198 = "AG", RSID_2 = "GG") + @test Ψ.args[2].treatment_values == (RSID_198 = "AG", RSID_2 = "AA") + new_Ψ = TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs) + for index in 1:length(Ψ.args) + @test new_Ψ.args[index].outcome == Ψ.args[index].outcome + @test new_Ψ.args[index].outcome_extra_covariates == (Symbol(22001),) + @test new_Ψ.args[index].treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2),) + end + @test new_Ψ.args[1].treatment_values == (RSID_198 = "GA", RSID_2 = "GG") + @test new_Ψ.args[2].treatment_values == (RSID_198 = "GA", RSID_2 = "AA") + # If the allele is not present variants_alleles = Dict(:RSID_198 => Set(["AA"])) @test_throws TargeneCore.AbsentAlleleError("RSID_198", "AG") TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs) @@ -95,8 +143,8 @@ end ## Estimands file: output_estimands = deserialize("final.estimands.jls").estimands - # There are 5 initial estimands containing a * - # Those are duplicated for each of the 4 targets. + # There are 5 initial estimands containing a :ALL + # Those are duplicated for each of the 4 outcomes. @test length(output_estimands) == 20 # In all cases the PCs are appended to the confounders. for Ψ ∈ output_estimands @@ -120,10 +168,11 @@ end @test Ψ.outcome_extra_covariates == (Symbol("22001"),) # Input Estimand 5: GA is corrected to AG to match the data - elseif Ψ isa TMLE.StatisticalCM && Ψ.treatment_values == (RSID_198 = "AG", RSID_2 = "GG") - @test Ψ.treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2)) - @test Ψ.outcome_extra_covariates == (Symbol("22001"),) - + elseif Ψ isa TMLE.ComposedEstimand + @test Ψ.args[1].treatment_values == (RSID_198 = "AG", RSID_2 = "GG") + @test Ψ.args[2].treatment_values == (RSID_198 = "AG", RSID_2 = "AA") + @test Ψ.args[1].treatment_confounders == Ψ.args[2].treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2)) + @test Ψ.args[1].outcome_extra_covariates == Ψ.args[2].outcome_extra_covariates == (Symbol("22001"),) else throw(AssertionError(string("Which input did this output come from: ", Ψ))) end @@ -142,7 +191,7 @@ end tl_inputs(parsed_args) # The IATES are the most sensitives outestimands = deserialize("final.estimands.jls").estimands - @test all(Ψ isa Union{TMLE.StatisticalCM, TMLE.StatisticalATE} for Ψ in outestimands) + @test all(Ψ isa Union{TMLE.StatisticalCM, TMLE.StatisticalATE, ComposedEstimand} for Ψ in outestimands) @test size(outestimands, 1) == 16 cleanup() diff --git a/test/tl_inputs/test_utils.jl b/test/tl_inputs/test_utils.jl index 038e709..8783cda 100644 --- a/test/tl_inputs/test_utils.jl +++ b/test/tl_inputs/test_utils.jl @@ -6,7 +6,6 @@ function cleanup(;prefix="final.") end end - function make_estimands_configuration() estimands = [ IATE( @@ -32,11 +31,20 @@ function make_estimands_configuration() treatment_confounders = (RSID_2 = [], RSID_198 = []), outcome_extra_covariates = [22001] ), - CM( - outcome = "ALL", - treatment_values = (RSID_2 = "GG", RSID_198 = "GA"), - treatment_confounders = (RSID_2 = [], RSID_198 = []), - outcome_extra_covariates = [22001] + ComposedEstimand( + TMLE.joint_estimand, ( + CM( + outcome = "ALL", + treatment_values = (RSID_2 = "GG", RSID_198 = "AG"), + treatment_confounders = (RSID_2 = [], RSID_198 = []), + outcome_extra_covariates = [22001] + ), + CM( + outcome = "ALL", + treatment_values = (RSID_2 = "AA", RSID_198 = "AG"), + treatment_confounders = (RSID_2 = [], RSID_198 = []), + outcome_extra_covariates = [22001] + )) ) ] return Configuration(estimands=estimands) diff --git a/test/tl_inputs/tl_inputs.jl b/test/tl_inputs/tl_inputs.jl index efb9541..4ded885 100644 --- a/test/tl_inputs/tl_inputs.jl +++ b/test/tl_inputs/tl_inputs.jl @@ -83,103 +83,6 @@ end @test_throws TargeneCore.NotAllVariantsFoundError(variants) TargeneCore.call_genotypes(bgen_dir, variants, 0.95;) end - -@testset "Test positivity_constraint" begin - data = DataFrame( - A = [1, 1, 0, 1, 0, 2, 2, 1], - B = ["AC", "CC", "AA", "AA", "AA", "AA", "AA", "AA"] - ) - ## One variable - freqs = TargeneCore.frequency_table(data, [:A]) - @test freqs == Dict( - (A = 0,) => 0.25, - (A = 2,) => 0.25, - (A = 1,) => 0.5 - ) - Ψ = CM( - outcome = :toto, - treatment_values = (A=1,), - treatment_confounders = (A=[],) - ) - @test TargeneCore.setting_iterator(Ψ) == ((A = 1,),) - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.4) == true - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.6) == false - - Ψ = ATE( - outcome = :toto, - treatment_values= (A= (case=1, control=0),), - treatment_confounders = (A=[],) - ) - @test collect(TargeneCore.setting_iterator(Ψ)) == [(A = 1,), (A = 0,)] - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.2) == true - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.3) == false - - ## Two variables - # Treatments are sorted: [:B, :A] -> [:A, :B] - freqs = TargeneCore.frequency_table(data, [:B, :A]) - @test freqs == Dict( - (A = 1, B = "CC") => 0.125, - (A = 1, B = "AA") => 0.25, - (A = 0, B = "AA") => 0.25, - (A = 1, B = "AC") => 0.125, - (A = 2, B = "AA") => 0.25 - ) - - Ψ = CM( - outcome = :toto, - treatment_values = (B = "CC", A = 1), - treatment_confounders = (B = [], A = []) - ) - @test TargeneCore.setting_iterator(Ψ) == ((A = 1, B = "CC"),) - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.1) == true - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.15) == false - - Ψ = ATE( - outcome = :toto, - treatment_values = (B=(case="AA", control="AC"), A=(case=1, control=1),), - treatment_confounders = (B = (), A = (),) - ) - @test collect(TargeneCore.setting_iterator(Ψ)) == [(A = 1, B = "AA"), (A = 1, B = "AC")] - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.1) == true - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.2) == false - - Ψ = IATE( - outcome = :toto, - treatment_values = (B=(case="AC", control="AA"), A=(case=1, control=0),), - treatment_confounders = (B=(), A=()), - ) - @test collect(TargeneCore.setting_iterator(Ψ)) == [ - (A = 1, B = "AC") (A = 1, B = "AA") - (A = 0, B = "AC") (A = 0, B = "AA")] - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=1.) == false - freqs = Dict( - (A = 1, B = "CC") => 0.125, - (A = 1, B = "AA") => 0.25, - (A = 0, B = "AA") => 0.25, - (A = 0, B = "AC") => 0.25, - (A = 1, B = "AC") => 0.125, - (A = 2, B = "AA") => 0.25 - ) - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.3) == false - @test TargeneCore.satisfies_positivity(Ψ, freqs, positivity_constraint=0.1) == true - - Ψ = IATE( - outcome = :toto, - treatment_values = (B=(case="AC", control="AA"), A=(case=1, control=0), C=(control=0, case=2)), - treatment_confounders = (B=(), A=(), C=()) - ) - expected_settings = Set([ - (A = 1, B = "AC", C = 0), - (A = 0, B = "AC", C = 0), - (A = 1, B = "AA", C = 0), - (A = 0, B = "AA", C = 0), - (A = 1, B = "AC", C = 2), - (A = 0, B = "AC", C = 2), - (A = 1, B = "AA", C = 2), - (A = 0, B = "AA", C = 2)]) - @test expected_settings == Set(TargeneCore.setting_iterator(Ψ)) -end - end true \ No newline at end of file