Skip to content

Commit

Permalink
* bug fixed: params from base class can't be parsed in some OPs (#311)
Browse files Browse the repository at this point in the history
+ add a test case to test param parsing
  • Loading branch information
HYLcool authored May 8, 2024
1 parent 17c2c76 commit 6b6d9e0
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 17 deletions.
6 changes: 5 additions & 1 deletion data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,15 +502,19 @@ def _collect_config_info_from_class_docs(configurable_ops, parser):
:param configurable_ops: a list of ops to be added, each item is
a pair of op_name and op_class
:param parser: the jsonargparse parser to update
:return: all params of each OP in a dictionary
"""

op_params = {}
for op_name, op_class in configurable_ops:
parser.add_class_arguments(
params = parser.add_class_arguments(
theclass=op_class,
nested_key=op_name,
fail_untyped=False,
instantiate=False,
)
op_params[op_name] = params
return op_params


def sort_op_by_types_and_names(op_name_classes):
Expand Down
8 changes: 4 additions & 4 deletions data_juicer/ops/filter/image_face_ratio_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def __init__(self,
self.min_ratio = min_ratio
self.max_ratio = max_ratio

self.extra_kwargs = {
k: kwargs.get(k, v)
for k, v in self._default_kwargs.items()
}
self.extra_kwargs = self._default_kwargs
for key in kwargs:
if key in self.extra_kwargs:
self.extra_kwargs[key] = kwargs[key]

if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
Expand Down
8 changes: 4 additions & 4 deletions data_juicer/ops/filter/video_motion_score_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def __init__(self,
self.max_score = max_score
self.sampling_fps = sampling_fps

self.extra_kwargs = {
k: kwargs.get(k, v)
for k, v in self._default_kwargs.items()
}
self.extra_kwargs = self._default_kwargs
for key in kwargs:
if key in self.extra_kwargs:
self.extra_kwargs[key] = kwargs[key]

if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
Expand Down
8 changes: 4 additions & 4 deletions data_juicer/ops/mapper/image_face_blur_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def __init__(self,
self.blur_type = blur_type
self.radius = radius

self.extra_kwargs = {
k: kwargs.get(k, v)
for k, v in self._default_kwargs.items()
}
self.extra_kwargs = self._default_kwargs
for key in kwargs:
if key in self.extra_kwargs:
self.extra_kwargs[key] = kwargs[key]

# Initialize face detector
self.detector = dlib.get_frontal_face_detector()
Expand Down
8 changes: 4 additions & 4 deletions data_juicer/ops/mapper/video_face_blur_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def __init__(self,
self.blur_type = blur_type
self.radius = radius

self.extra_kwargs = {
k: kwargs.get(k, v)
for k, v in self._default_kwargs.items()
}
self.extra_kwargs = self._default_kwargs
for key in kwargs:
if key in self.extra_kwargs:
self.extra_kwargs[key] = kwargs[key]

# Initialize face detector
self.detector = dlib.get_frontal_face_detector()
Expand Down
21 changes: 21 additions & 0 deletions tests/config/test_config_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,27 @@ def test_mixture_cfg(self):
}
})

def test_op_params_parsing(self):
    """Verify that base-class params are collected for every OP.

    Registers all operators found in ``OPERATORS`` on a fresh
    jsonargparse parser via ``_collect_config_info_from_class_docs`` and
    asserts that, for each OP, every expected base-class parameter shows
    up in the returned params under the key ``<op_name>.<param_name>``.
    """
    from jsonargparse import ArgumentParser
    from data_juicer.config.config import (
        _collect_config_info_from_class_docs,
        sort_op_by_types_and_names,
    )
    from data_juicer.ops.base_op import OPERATORS

    # Parameters this test expects every OP to expose.
    expected_base_params = {
        'text_key', 'image_key', 'audio_key', 'video_key', 'accelerator',
        'spec_numprocs', 'cpu_required', 'mem_required', 'use_actor',
    }

    arg_parser = ArgumentParser(default_env=True, default_config_files=None)
    sorted_ops = sort_op_by_types_and_names(OPERATORS.modules.items())
    collected = _collect_config_info_from_class_docs(sorted_ops, arg_parser)

    for name, parsed_params in collected.items():
        for param in expected_base_params:
            # Collected keys are namespaced by the OP name.
            self.assertIn(f'{name}.{param}', parsed_params)


if __name__ == '__main__':
unittest.main()

0 comments on commit 6b6d9e0

Please sign in to comment.