Merge pull request #40 from LSSTDESC/u/jrbogart/rowgroup_bug
fix problem with row groups when writing galaxy flux files
JoanneBogart authored Jan 17, 2023
2 parents adde380 + f71bb4e commit ff1b744
Showing 1 changed file with 74 additions and 81 deletions: python/desc/skycatalogs/catalog_creator.py
@@ -91,6 +91,7 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, l_bnd, u_bnd):
o_list = galaxy_collection[l_bnd : u_bnd]
out_dict['galaxy_id'] = [o.get_native_attribute('galaxy_id') for o in o_list]
all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list]
###all_fluxes = [galaxy_collection[i].get_LSST_fluxes(as_dict=False) for i in range(l_bnd, u_bnd)]
all_fluxes_transpose = zip(*all_fluxes)
for i, band in enumerate(LSST_BANDS):
v = all_fluxes_transpose.__next__()
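
For context, zip(*all_fluxes) transposes the per-object flux lists into one iterable per band, so each next() call yields the fluxes of every object in the chunk for a single band. A minimal, self-contained illustration with made-up values (a plain tuple of band names stands in for LSST_BANDS):

all_fluxes = [
    [1.1, 2.2, 3.3, 4.4, 5.5, 6.6],   # object 0: fluxes for u, g, r, i, z, y
    [7.7, 8.8, 9.9, 1.0, 2.0, 3.0],   # object 1
]
all_fluxes_transpose = zip(*all_fluxes)
for band in ('u', 'g', 'r', 'i', 'z', 'y'):   # stand-in for LSST_BANDS
    v = next(all_fluxes_transpose)            # e.g. (1.1, 7.7) for band 'u'
    print(f'lsst_flux_{band}', list(v))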
@@ -253,8 +254,6 @@ def create_galaxy_catalog(self):
# Save cosmology in case we need to write parameters out later
self._cosmology = gal_cat.cosmology

###self._mag_norm_f = MagNorm(self._cosmology)

arrow_schema = make_galaxy_schema(self._logname,
self._sed_subdir,
self._knots)
@@ -295,7 +294,6 @@ def create_galaxy_pixel(self, pixel, gal_cat, arrow_schema):
tophat_disk_re = r'sed_(?P<start>\d+)_(?P<width>\d+)_disk'

# Number of rows to include in a row group
#stride = 1000000
stride = self._galaxy_stride

hp_filter = [f'healpix_pixel=={pixel}']
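
As an aside, the tophat_disk_re pattern above pulls the bin start and width out of a tophat SED column name. A small, self-contained check (the column name 'sed_1000_246_disk' is illustrative; the actual names come from the galaxy catalog):

import re

tophat_disk_re = r'sed_(?P<start>\d+)_(?P<width>\d+)_disk'
m = re.fullmatch(tophat_disk_re, 'sed_1000_246_disk')
if m:
    start = int(m.group('start'))   # 1000
    width = int(m.group('width'))   # 246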
@@ -306,7 +304,6 @@ def create_galaxy_pixel(self, pixel, gal_cat, arrow_schema):
# to_fetch = all columns of interest in gal_cat
non_sed = ['galaxy_id', 'ra', 'dec', 'redshift', 'redshiftHubble',
'peculiarVelocity', 'shear_1', 'shear_2',
#'convergence', 'position_angle_true',
'convergence',
'size_bulge_true', 'size_minor_bulge_true', 'sersic_bulge',
'size_disk_true', 'size_minor_disk_true', 'sersic_disk']
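
For context, a hedged sketch (not the CatalogCreator code) of how a fixed column list like this can be combined with the tophat SED columns matched by the pattern shown earlier; the quantity names below are made-up stand-ins for what gal_cat.list_all_quantities() might return:

import re

quantities = ['galaxy_id', 'ra', 'dec', 'sed_1000_246_disk',
              'sed_1246_306_disk', 'sed_1000_246_bulge']     # illustrative sample
tophat_disk_re = re.compile(r'sed_(?P<start>\d+)_(?P<width>\d+)_disk')
non_sed = ['galaxy_id', 'ra', 'dec']                         # abbreviated
to_fetch = non_sed + [q for q in quantities if tophat_disk_re.fullmatch(q)]
# to_fetch == ['galaxy_id', 'ra', 'dec', 'sed_1000_246_disk', 'sed_1246_306_disk']

An analogous bulge pattern would pick up the bulge SED columns.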
@@ -479,7 +476,6 @@ def create_galaxy_flux_catalog(self, config_file=None):
self._gal_flux_schema = make_galaxy_flux_schema(self._logname)

if not config_file:
#config_file = self._written_config
config_file = self.write_config(path_only=True)
if not self._cat:
self._cat = open_catalog(config_file,
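
The flux schema itself is not shown in this diff; judging only from the out_dict keys used below, it presumably holds a galaxy id plus one flux column per LSST band. A guess at its shape (the real make_galaxy_flux_schema may use different types or carry extra fields and metadata):

import pyarrow as pa

gal_flux_schema = pa.schema([
    ('galaxy_id', pa.int64()),
    ('lsst_flux_u', pa.float32()),
    ('lsst_flux_g', pa.float32()),
    ('lsst_flux_r', pa.float32()),
    ('lsst_flux_i', pa.float32()),
    ('lsst_flux_z', pa.float32()),
    ('lsst_flux_y', pa.float32()),
])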
@@ -526,91 +522,88 @@ def _create_galaxy_flux_pixel(self, pixel):
# Use same value for row group as was used for main file
stride = self._galaxy_stride

# If there are multiple row groups, each is stored in a separate
# object collection. Need to loop over them
object_list = self._cat.get_objects_by_hp(pixel,
obj_type_set={'galaxy'})
object_coll = object_list.get_collections()[0]
writer = None
l_bnd = 0
u_bnd = min(len(object_coll), stride)
rg_written = 0
while u_bnd > l_bnd:
o_list = object_coll[l_bnd : u_bnd]
self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}')
global _galaxy_collection

for object_coll in object_list.get_collections():
_galaxy_collection = object_coll
# prefetch everything we need. Getting a quantity for one object
# ensures the quantity is read for the whole row group
# (In fact currently all row groups are read, but that could and
# should change)
for att in ['galaxy_id', 'shear_1', 'shear_2', 'convergence',
'redshift_hubble', 'MW_av', 'MW_rv', 'sed_val_bulge',
'sed_val_disk', 'sed_val_knots'] :
v = o_list[0].get_native_attribute(att)

global _galaxy_collection
_galaxy_collection = object_coll

out_dict = {'galaxy_id': [], 'lsst_flux_u' : [],
'lsst_flux_g' : [], 'lsst_flux_r' : [],
'lsst_flux_i' : [], 'lsst_flux_z' : [],
'lsst_flux_y' : []}

n_parallel = self._flux_parallel

if n_parallel == 1:
n_per = u_bnd - l_bnd
else:
n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel)
l = l_bnd
u = min(l_bnd + n_per, u_bnd)
readers = []

if n_parallel == 1:
out_dict = _do_galaxy_flux_chunk(None, _galaxy_collection,
l, u)
else:
# Expect to be able to do about 1500/minute/process
tm = int((n_per*60)/500) # Give ourselves a cushion
self._logger.info(f'Using timeout value {tm} for {n_per} sources')
p_list = []
for i in range(n_parallel):
conn_rd, conn_wrt = Pipe(duplex=False)
readers.append(conn_rd)

# For debugging call directly
proc = Process(target=_do_galaxy_flux_chunk,
name=f'proc_{i}',
args=(conn_wrt, _galaxy_collection,l, u))
proc.start()
p_list.append(proc)
l = u
u = min(l + n_per, u_bnd)

self._logger.debug('Processes started')

for i in range(n_parallel):
ready = readers[i].poll(tm)
if not ready:
self._logger.error(f'Process {i} timed out after {tm} sec')
sys.exit(1)
dat = readers[i].recv()
for k in ['galaxy_id', 'lsst_flux_u', 'lsst_flux_g',
'lsst_flux_r', 'lsst_flux_i', 'lsst_flux_z',
'lsst_flux_y']:
out_dict[k] += dat[k]
for p in p_list:
p.join()

out_df = pd.DataFrame.from_dict(out_dict)
out_table = pa.Table.from_pandas(out_df,
schema=self._gal_flux_schema)

if not writer:
writer = pq.ParquetWriter(output_path, self._gal_flux_schema)
writer.write_table(out_table)
l_bnd = u_bnd
u_bnd = min(l_bnd + stride, len(object_coll))

rg_written +=1
v = object_coll.get_native_attribute(att)
l_bnd = 0
u_bnd = min(len(object_coll), stride)
rg_written = 0
while u_bnd > l_bnd:
self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}')

out_dict = {'galaxy_id': [], 'lsst_flux_u' : [],
'lsst_flux_g' : [], 'lsst_flux_r' : [],
'lsst_flux_i' : [], 'lsst_flux_z' : [],
'lsst_flux_y' : []}

n_parallel = self._flux_parallel

if n_parallel == 1:
n_per = u_bnd - l_bnd
else:
n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel)
l = l_bnd
u = min(l_bnd + n_per, u_bnd)
readers = []

if n_parallel == 1:
out_dict = _do_galaxy_flux_chunk(None, _galaxy_collection,
l, u)
else:
# Expect to be able to do about 1500/minute/process
tm = int((n_per*60)/500) # Give ourselves a cushion
self._logger.info(f'Using timeout value {tm} for {n_per} sources')
p_list = []
for i in range(n_parallel):
conn_rd, conn_wrt = Pipe(duplex=False)
readers.append(conn_rd)

# For debugging call directly
proc = Process(target=_do_galaxy_flux_chunk,
name=f'proc_{i}',
args=(conn_wrt, _galaxy_collection,l, u))
proc.start()
p_list.append(proc)
l = u
u = min(l + n_per, u_bnd)

self._logger.debug('Processes started') # outside for loop
for i in range(n_parallel):
ready = readers[i].poll(tm)
if not ready:
self._logger.error(f'Process {i} timed out after {tm} sec')
sys.exit(1)
dat = readers[i].recv() # lines up with if
for k in ['galaxy_id', 'lsst_flux_u', 'lsst_flux_g',
'lsst_flux_r', 'lsst_flux_i', 'lsst_flux_z',
'lsst_flux_y']:
out_dict[k] += dat[k]
for p in p_list:
p.join()

out_df = pd.DataFrame.from_dict(out_dict) # outdent from for
out_table = pa.Table.from_pandas(out_df,
schema=self._gal_flux_schema)

if not writer:
writer = pq.ParquetWriter(output_path, self._gal_flux_schema)
writer.write_table(out_table)
l_bnd = u_bnd
u_bnd = min(l_bnd + stride, len(object_coll))

rg_written +=1

writer.close()
self._logger.debug(f'# row groups written to flux file: {rg_written}')
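
The rewritten method above follows a chunk, fan out, gather, write-row-group pattern: loop over every object collection (the bug fix), slice each collection by stride, split each slice across worker processes connected by pipes, poll the pipes with a timeout, then write the merged chunk as one row group. Below is a minimal, self-contained sketch of that pattern, not skyCatalogs code: names such as _flux_chunk and write_fluxes, the dummy flux values, and the two-column schema are invented for illustration; the real code uses _do_galaxy_flux_chunk, the skyCatalogs collection API, and the full six-band schema.

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from multiprocessing import Pipe, Process


def _flux_chunk(send_conn, ids, l_bnd, u_bnd):
    # Stand-in for _do_galaxy_flux_chunk: "compute" fluxes for ids[l_bnd:u_bnd]
    out = {'galaxy_id': list(ids[l_bnd:u_bnd]),
           'lsst_flux_r': [0.0] * (u_bnd - l_bnd)}      # dummy fluxes
    if send_conn is not None:
        send_conn.send(out)
    else:
        return out


def write_fluxes(collections, output_path, stride=1_000_000,
                 n_parallel=2, timeout=60):
    schema = pa.schema([('galaxy_id', pa.int64()),
                        ('lsst_flux_r', pa.float64())])
    writer = pq.ParquetWriter(output_path, schema)
    rg_written = 0
    for coll in collections:                  # every collection, not just [0]
        l_bnd, u_bnd = 0, min(len(coll), stride)
        while u_bnd > l_bnd:
            out_dict = {'galaxy_id': [], 'lsst_flux_r': []}
            n_per = -(-(u_bnd - l_bnd) // n_parallel)    # ceiling division
            readers, procs = [], []
            l = l_bnd
            for _ in range(n_parallel):
                u = min(l + n_per, u_bnd)
                if u <= l:
                    break
                conn_rd, conn_wrt = Pipe(duplex=False)
                readers.append(conn_rd)
                p = Process(target=_flux_chunk, args=(conn_wrt, coll, l, u))
                p.start()
                procs.append(p)
                l = u
            for rd in readers:
                if not rd.poll(timeout):
                    raise RuntimeError('flux worker timed out')
                dat = rd.recv()
                for k in out_dict:
                    out_dict[k] += dat[k]
            for p in procs:
                p.join()
            table = pa.Table.from_pandas(pd.DataFrame(out_dict), schema=schema)
            writer.write_table(table)         # one row group for this chunk
            rg_written += 1
            l_bnd = u_bnd
            u_bnd = min(l_bnd + stride, len(coll))
    writer.close()
    return rg_written

For example, write_fluxes([list(range(10))], 'fluxes.parquet', stride=4) would write three row groups; on platforms that spawn worker processes, run it under an if __name__ == '__main__': guard.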
