-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds force_pointing
kw to match_params
#924
Changes from 4 commits
764737a
f0ef44d
48022ff
d38dd56
a4e7c49
7203bb3
a8a6b7c
ef0d57e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
{ | ||
// See https://go.microsoft.com/fwlink/?LinkId=827846 to learn about workspace recommendations. | ||
// Extension identifier format: ${publisher}.${name}. Example: vscode.csharp | ||
|
||
// List of extensions which should be recommended for users of this workspace. | ||
"recommendations": [ | ||
|
||
], | ||
// List of extensions recommended by VS Code that should not be recommended for users of this workspace. | ||
"unwantedRecommendations": [ | ||
|
||
], | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
{ | ||
"[restructuredtext]": { | ||
"editor.quickSuggestions": { | ||
"other": "off", | ||
"comments": "off", | ||
"strings": "off" | ||
} | ||
}, | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
import matplotlib.pyplot as plt | ||
from scipy.optimize import linear_sum_assignment | ||
|
||
from sotodlib.core import AxisManager | ||
from sotodlib.coords import optics | ||
from sotodlib.core import metadata | ||
from sotodlib.io.metadata import write_dataset | ||
|
@@ -263,7 +264,7 @@ def as_array(self): | |
return np.array(data, dtype=dtype) | ||
|
||
@classmethod | ||
def from_aman(cls, aman, stream_id, det_cal=None, name=None): | ||
def from_aman(cls, aman, stream_id, det_cal=None, name=None, pointing: Optional[AxisManager]=None): | ||
""" | ||
Load a resonator set from a Context object based on an obs_id | ||
|
||
|
@@ -276,6 +277,10 @@ def from_aman(cls, aman, stream_id, det_cal=None, name=None): | |
det_cal: AxisManager | ||
Detector calibration metadata. If not specified, will default to | ||
``aman.det_cal`` | ||
pointing: Optional[AxisManager] | ||
AxisManager containing pointing metadata. If set, this should be an | ||
AxisManager containing the fields ``xi`` and ``eta``, and resonator | ||
pointing information will be added from here. | ||
""" | ||
m = aman.det_info.stream_id == stream_id | ||
if not np.any(m): | ||
|
@@ -300,6 +305,10 @@ def from_aman(cls, aman, stream_id, det_cal=None, name=None): | |
idx=i, is_north=is_north, res_freq=res_freq, smurf_band=band, | ||
smurf_channel=channel, readout_id=readout_id, bg=bg | ||
) | ||
if pointing is not None: | ||
res.xi = pointing.xi[ri] | ||
res.eta = pointing.eta[ri] | ||
|
||
resonators.append(res) | ||
|
||
return cls(resonators, name=name) | ||
|
@@ -548,6 +557,17 @@ def add_pointing(self, bands, chans, xis, etas): | |
r.xi = xis[idx] | ||
r.eta = etas[idx] | ||
|
||
|
||
def get_det_type_mask(arr: np.ndarray, det_types: List[str]) -> np.ndarray: | ||
""" | ||
Returns a boolean mask of all dets that have specific detector types. | ||
""" | ||
return np.logical_or.reduce([ | ||
arr['det_type'] == dt | ||
for dt in det_types | ||
]) | ||
|
||
|
||
@dataclass | ||
class MatchParams: | ||
""" | ||
|
@@ -567,9 +587,9 @@ class MatchParams: | |
penalty to apply to leaving a resonator with a good qi unassigned | ||
good_res_qi_thresh (float): | ||
qi threshold that is considered "good" | ||
force_src_pointing (bool): | ||
If true, will assign a np.inf penalty to leaving a src resonator | ||
with a provided pointing unmatched. | ||
enforce_pointing_reqs (bool): | ||
If True, will enforce pointing requirements that depend on resonator | ||
type. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Either here, or somewhere else, explain in words what those requirements are (i.e. OPTC must have a position match, DARK/SLOT may have a position match, and everything else must not). |
||
assigned_bg_unmatched_pen (float): | ||
Penalty to apply to leaving a resonator with an assigned bg | ||
unmatched | ||
|
@@ -589,7 +609,8 @@ class MatchParams: | |
dist_width: float =0.01 | ||
unmatched_good_res_pen: float = 10. | ||
good_res_qi_thresh: float = 100e3 | ||
force_src_pointing: bool = False | ||
enforce_pointing_reqs: bool = False | ||
|
||
assigned_bg_unmatched_pen: float = 100000 | ||
unassigned_bg_unmatched_pen: float = 10000 | ||
assigned_bg_mismatch_pen: float = 100000 | ||
|
@@ -615,10 +636,9 @@ class Match: | |
difference between `src` and `dst` res-sets, except: | ||
|
||
- When merged, smurf-data such as band, channel, and res-idx will be taken | ||
from the ``src`` res-set | ||
- The ``force_src_pointing`` param can be used to assign a very high penalty | ||
to leaving any `src` resonator that has pointing info unassigned. | ||
|
||
from the ``src`` res-set, while detector information will be taken from | ||
the ``dst`` set. | ||
|
||
Args: | ||
src (ResSet): | ||
The source resonator set | ||
|
@@ -669,7 +689,7 @@ def __init__(self, src: ResSet, dst: ResSet, match_pars: Optional[MatchParams]=N | |
self.matching, self.merged = self._match() | ||
self.stats = self.get_stats() | ||
|
||
def _get_biadjacency_matrix(self): | ||
def _get_biadjacency_matrix(self) -> np.ndarray: | ||
src_arr = self.src.as_array() | ||
dst_arr = self.dst.as_array() | ||
|
||
|
@@ -679,6 +699,29 @@ def _get_biadjacency_matrix(self): | |
m = src_arr['is_north'][:, None] != dst_arr['is_north'][None, :] | ||
mat[m] = np.inf | ||
|
||
if self.match_pars.enforce_pointing_reqs: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So in this block, 'DARK' and 'SLOT' are implicitly untouched. I think you need to at least comment on DARK/SLOT, if not check explicitly that everything you haven't directly modified is listed as DARK or SLOT. |
||
src_has_pointing = np.isfinite(src_arr['xi']) & np.isfinite(src_arr['eta']) | ||
dst_has_pointing = np.isfinite(dst_arr['xi']) & np.isfinite(dst_arr['eta']) | ||
|
||
src_no_match = get_det_type_mask(src_arr, ['NC']) | ||
dst_no_match = get_det_type_mask(dst_arr, ['NC']) | ||
mat[src_no_match, :] = np.inf | ||
mat[:, dst_no_match] = np.inf | ||
|
||
src_pointing_forbidden = get_det_type_mask(src_arr, ['UNRT', 'SQID', 'BARE']) | ||
dst_pointing_forbidden = get_det_type_mask(dst_arr, ['UNRT', 'SQID', 'BARE']) | ||
m = src_pointing_forbidden[:, None] & dst_has_pointing[None, :] | ||
mat[m] = np.inf | ||
m = src_has_pointing[:, None] & dst_pointing_forbidden[None, :] | ||
mat[m] = np.inf | ||
|
||
src_pointing_required = get_det_type_mask(src_arr, ['OPTC']) | ||
dst_pointing_required = get_det_type_mask(dst_arr, ['OPTC']) | ||
m = src_pointing_required[:, None] & (~dst_has_pointing[None, :]) | ||
mat[m] = np.inf | ||
m = (~src_has_pointing[:, None]) & dst_pointing_required[None, :] | ||
mat[m] = np.inf | ||
|
||
# Frequency offset | ||
df = src_arr['res_freq'][:, None] - dst_arr['res_freq'][None, :] | ||
df -= self.match_pars.freq_offset_mhz | ||
|
@@ -705,7 +748,7 @@ def _get_biadjacency_matrix(self): | |
|
||
return mat | ||
|
||
def _get_unassigned_costs(self, rs, force_if_pointing=True): | ||
def _get_unassigned_costs(self, rs): | ||
ra = rs.as_array() | ||
|
||
arr = np.zeros(len(rs)) | ||
|
@@ -719,11 +762,6 @@ def _get_unassigned_costs(self, rs, force_if_pointing=True): | |
arr[bg_assigned] += self.match_pars.assigned_bg_unmatched_pen | ||
arr[~bg_assigned] += self.match_pars.unassigned_bg_unmatched_pen | ||
|
||
# Infinite cost if has pointing | ||
if force_if_pointing: | ||
m = ~np.isnan(ra['xi']) | ||
arr[m] = np.inf | ||
|
||
return arr | ||
|
||
|
||
|
@@ -741,12 +779,9 @@ def _match(self): | |
|
||
mat_full[:len(self.src), :len(self.dst)] = self._get_biadjacency_matrix() | ||
mat_full[:len(self.src), len(self.dst):] = \ | ||
self._get_unassigned_costs( | ||
self.src, | ||
force_if_pointing=self.match_pars.force_src_pointing | ||
)[:, None] | ||
self._get_unassigned_costs(self.src)[:, None] | ||
mat_full[len(self.src):, :len(self.dst)] = \ | ||
self._get_unassigned_costs(self.dst, force_if_pointing=False)[None, :] | ||
self._get_unassigned_costs(self.dst)[None, :] | ||
mat_full[len(self.src):, len(self.dst):] = 0 | ||
|
||
self.matching = np.array(linear_sum_assignment(mat_full)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you are looking for
np.isin
!