diff --git a/src/coffea/jetmet_tools/CorrectedMETFactory.py b/src/coffea/jetmet_tools/CorrectedMETFactory.py index 5853dc5b2..25d979b84 100644 --- a/src/coffea/jetmet_tools/CorrectedMETFactory.py +++ b/src/coffea/jetmet_tools/CorrectedMETFactory.py @@ -4,11 +4,11 @@ def corrected_polar_met( - met_pt, met_phi, jet_pt, jet_phi, jet_pt_orig, positive=None, dx=None, dy=None + met_pt, met_phi, jet_pt, jet_phi, positive=None, dx=None, dy=None ): sj, cj = numpy.sin(jet_phi), numpy.cos(jet_phi) - x = met_pt * numpy.cos(met_phi) + awkward.sum((jet_pt - jet_pt_orig) * cj, axis=1) - y = met_pt * numpy.sin(met_phi) + awkward.sum((jet_pt - jet_pt_orig) * sj, axis=1) + x = met_pt * numpy.cos(met_phi) - awkward.sum(jet_pt * cj, axis=1) + y = met_pt * numpy.sin(met_phi) - awkward.sum(jet_pt * sj, axis=1) if positive is not None and dx is not None and dy is not None: x = x + dx if positive else x - dx y = y + dy if positive else y - dy @@ -36,7 +36,7 @@ def __init__(self, name_map): self.name_map = name_map - def build(self, in_MET, in_corrected_jets): + def build(self, in_MET, type1_MET, in_corrected_jets): if not isinstance( in_MET, (awkward.highlevel.Array, dask_awkward.Array) ) or not isinstance( @@ -60,7 +60,6 @@ def switch_properties(raw_met, corrected_jets, dx, dy, positive, save_orig): raw_met[self.name_map["METphi"]], corrected_jets[self.name_map["JetPt"]], corrected_jets[self.name_map["JetPhi"]], - corrected_jets[self.name_map["ptRaw"]], positive=positive, dx=dx, dy=dy, @@ -144,10 +143,10 @@ def create_variants(raw_met, corrected_jets_or_variants, dx, dy): out_dict["MET_UnclusteredEnergy"] = dask_awkward.map_partitions( create_variants, - MET, + type1_MET, corrected_jets, - MET[self.name_map["UnClusteredEnergyDeltaX"]], - MET[self.name_map["UnClusteredEnergyDeltaY"]], + type1_MET[self.name_map["UnClusteredEnergyDeltaX"]], + type1_MET[self.name_map["UnClusteredEnergyDeltaY"]], label="UnclusteredEnergy_met", ) diff --git a/tests/test_jetmet_tools.py b/tests/test_jetmet_tools.py index 3443f4837..61297fdbb 100644 --- a/tests/test_jetmet_tools.py +++ b/tests/test_jetmet_tools.py @@ -919,12 +919,12 @@ def smear_factor(jetPt, pt_gen, jersf): toc = time.time() print("setup corrected MET time =", toc - tic) - + rawmet = events.RawMET met = events.MET tic = time.time() # prof = pyinstrument.Profiler() # prof.start() - corrected_met = met_factory.build(met, corrected_jets) + corrected_met = met_factory.build(rawmet, met, corrected_jets) # prof.stop() toc = time.time()