diff --git a/src/hepconvert/copy_root.py b/src/hepconvert/copy_root.py index b0e1bac..81a7a29 100644 --- a/src/hepconvert/copy_root.py +++ b/src/hepconvert/copy_root.py @@ -251,19 +251,12 @@ def copy_root( ) } ) - for key in group: - if key in kb: - del chunk[key] + for key in group: + if key in kb: + del chunk[key] if first: first = False - if drop_branches: - branch_types = { - name: array.type - for name, array in chunk.items() - if name not in drop_branches - } - else: - branch_types = {name: array.type for name, array in chunk.items()} + branch_types = {name: array.type for name, array in chunk.items()} of.mktree( tree.name, branch_types, diff --git a/tests/test_copy_root.py b/tests/test_copy_root.py index 623f2e0..2f15b5a 100644 --- a/tests/test_copy_root.py +++ b/tests/test_copy_root.py @@ -66,13 +66,14 @@ def test_keep_branches(tmp_path): hepconvert.copy_root( Path(tmp_path) / "drop_branches.root", skhep_testdata.data_path("uproot-HZZ.root"), - drop_branches=["Jet_*", "MClepton_*"], + keep_branches="MClepton_*", counter_name=lambda counted: "N" + counted, force=True, ) original = uproot.open(skhep_testdata.data_path("uproot-HZZ.root")) file = uproot.open(Path(tmp_path) / "drop_branches.root") + file["events"].show() for key in original["events"].keys(): if key.startswith("MClepton_"): assert key in file["events"].keys() @@ -81,6 +82,7 @@ def test_keep_branches(tmp_path): ) else: assert key not in file["events"].keys() + file.close() def test_hepdata_example(tmp_path):