diff --git a/pyproject.toml b/pyproject.toml index a91b3a0..7ae40c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,10 @@ lint.ignore = [ "INP001", "I001", "FA100", + "FA102", + "SIM118", + "UP031", + "PT011", ] lint.select = ["ALL"] exclude = ["migrations"] diff --git a/src/django_viewcomponent/component.py b/src/django_viewcomponent/component.py index af74bd2..0c9796c 100644 --- a/src/django_viewcomponent/component.py +++ b/src/django_viewcomponent/component.py @@ -63,7 +63,7 @@ def get_template(self) -> Template: raise ImproperlyConfigured( f"Either 'template_name' or 'template' must be set for Component {type(self).__name__}." - f"Note: this attribute is not required if you are overriding the class's `get_template*()` methods." + f"Note: this attribute is not required if you are overriding the class's `get_template*()` methods.", ) def prepare_context( diff --git a/src/django_viewcomponent/component_registry.py b/src/django_viewcomponent/component_registry.py index 58fe3b3..4e73618 100644 --- a/src/django_viewcomponent/component_registry.py +++ b/src/django_viewcomponent/component_registry.py @@ -6,7 +6,7 @@ class NotRegistered(Exception): pass -class ComponentRegistry(object): +class ComponentRegistry: def __init__(self): self._registry = {} # component name -> component_class mapping @@ -14,7 +14,7 @@ def register(self, name=None, component=None): existing_component = self._registry.get(name) if existing_component and existing_component.class_hash != component.class_hash: raise AlreadyRegistered( - 'The component "%s" has already been registered' % name + 'The component "%s" has already been registered' % name, ) self._registry[name] = component diff --git a/src/django_viewcomponent/fields.py b/src/django_viewcomponent/fields.py index 8b97771..026cf78 100644 --- a/src/django_viewcomponent/fields.py +++ b/src/django_viewcomponent/fields.py @@ -1,18 +1,16 @@ -from typing import Optional - from django_viewcomponent.component_registry import registry as component_registry class FieldValue: def __init__( self, + content: str, dict_data: dict, - component: Optional[str] = None, + component: None, parent_component=None, - **kwargs, ): + self._content = content or "" self._dict_data = dict_data - self._content = self._dict_data.pop("content", "") self._component = component self._parent_component = parent_component @@ -26,18 +24,41 @@ def __str__(self): def render(self): from django_viewcomponent.component import Component - component_cls = None if isinstance(self._component, str): - component_cls = component_registry.get(self._component) - elif issubclass(self._component, Component): - component_cls = self._component + return self._render_for_component_cls( + component_registry.get(self._component), + ) + elif not isinstance(self._component, type) and callable(self._component): + # self._component is function + callable_component = self._component + result = callable_component(**self._dict_data) + + if isinstance(result, str): + return result + elif isinstance(result, Component): + # render component instance + return self._render_for_component_instance(result) + else: + raise ValueError( + f"Callable slot component must return str or Component instance. Got {result}", + ) + elif isinstance(self._component, type) and issubclass( + self._component, + Component, + ): + # self._component is Component class + return self._render_for_component_cls(self._component) else: - raise ValueError(f"Invalid component type {self._component}") + raise ValueError(f"Invalid component variable {self._component}") + def _render_for_component_cls(self, component_cls): component = component_cls( **self._dict_data, ) - component.component_name = self._component + + return self._render_for_component_instance(component) + + def _render_for_component_instance(self, component): component.component_context = self._parent_component.component_context with component.component_context.push(): @@ -86,13 +107,9 @@ def handle_call(self, content, **kwargs): class RendersOneField(BaseSlotField): def handle_call(self, content, **kwargs): - value_dict = { - "content": content, - "parent_component": self.parent_component, - **kwargs, - } value_instance = FieldValue( - dict_data=value_dict, + content=content, + dict_data={**kwargs}, component=self._component, parent_component=self.parent_component, ) @@ -103,13 +120,9 @@ def handle_call(self, content, **kwargs): class RendersManyField(BaseSlotField): def handle_call(self, content, **kwargs): - value_dict = { - "content": content, - "parent_component": self.parent_component, - **kwargs, - } value_instance = FieldValue( - dict_data=value_dict, + content=content, + dict_data={**kwargs}, component=self._component, parent_component=self.parent_component, ) diff --git a/src/django_viewcomponent/preview.py b/src/django_viewcomponent/preview.py index 18b348a..7cb9ffc 100644 --- a/src/django_viewcomponent/preview.py +++ b/src/django_viewcomponent/preview.py @@ -42,7 +42,8 @@ def __init_subclass__(cls, **kwargs): cls.preview_name = new_name cls.preview_view_component_path = os.path.abspath(inspect.getfile(cls)) cls.url = urljoin( - reverse("django_viewcomponent:preview-index"), cls.preview_name + "/" + reverse("django_viewcomponent:preview-index"), + cls.preview_name + "/", ) @classmethod diff --git a/src/django_viewcomponent/templatetags/viewcomponent_tags.py b/src/django_viewcomponent/templatetags/viewcomponent_tags.py index 015bdda..15e52f9 100644 --- a/src/django_viewcomponent/templatetags/viewcomponent_tags.py +++ b/src/django_viewcomponent/templatetags/viewcomponent_tags.py @@ -73,14 +73,14 @@ def render(self, context): if "content" in resolved_kwargs: raise ValueError( - "The 'content' kwarg is reserved and cannot be passed in component call tag" + "The 'content' kwarg is reserved and cannot be passed in component call tag", ) resolved_kwargs["content"] = content component_token, field_token = self.args[0].token.split(".") component_instance = FilterExpression(component_token, self.parser).resolve( - context + context, ) if not component_instance: raise ValueError(f"Component {component_token} not found in context") @@ -88,7 +88,7 @@ def render(self, context): field = getattr(component_instance, field_token, None) if not field: raise ValueError( - f"Field {field_token} not found in component {component_token}" + f"Field {field_token} not found in component {component_token}", ) if isinstance(field, BaseSlotField): @@ -115,7 +115,9 @@ def __repr__(self): return "" % ( self.name_fexp, getattr( - self, "nodelist", None + self, + "nodelist", + None, ), # 'nodelist' attribute only assigned later. ) @@ -185,7 +187,9 @@ def do_component(parser, token): bits = bits[:-2] component_name, context_args, context_kwargs = parse_component_with_arguments( - parser, bits, "component" + parser, + bits, + "component", ) nodelist: NodeList = parser.parse(parse_until=["endcomponent"]) parser.delete_first_token() @@ -217,7 +221,7 @@ def parse_component_with_arguments(parser, bits, tag_name): if tag_name != tag_args[0].token: raise RuntimeError( - f"Internal error: Expected tag_name to be {tag_name}, but it was {tag_args[0].token}" + f"Internal error: Expected tag_name to be {tag_name}, but it was {tag_args[0].token}", ) if len(tag_args) > 1: # At least one position arg, so take the first as the component name @@ -226,7 +230,7 @@ def parse_component_with_arguments(parser, bits, tag_name): context_kwargs = tag_kwargs else: raise TemplateSyntaxError( - f"Call the '{tag_name}' tag with a component name as the first parameter" + f"Call the '{tag_name}' tag with a component name as the first parameter", ) return component_name, context_args, context_kwargs diff --git a/tests/conftest.py b/tests/conftest.py index 499d3b5..d181dfb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,13 +27,13 @@ def pytest_configure(): "django.template.loaders.app_directories.Loader", "django_viewcomponent.loaders.ComponentLoader", ], - ) + ), ], "builtins": [ "django_viewcomponent.templatetags.viewcomponent_tags", ], }, - } + }, ], INSTALLED_APPS=[ "django.contrib.admin", diff --git a/tests/previews/simple_preview.py b/tests/previews/simple_preview.py index ff30429..595b2f1 100644 --- a/tests/previews/simple_preview.py +++ b/tests/previews/simple_preview.py @@ -35,7 +35,7 @@ def with_template_render(self, title="default title", **kwargs): {% load viewcomponent_tags %} {% component "example" title=title %} {% endcomponent %} - """ + """, ) # pass the title from the URL querystring to the context diff --git a/tests/test_layout.py b/tests/test_layout.py index dd77606..4c81b4e 100644 --- a/tests/test_layout.py +++ b/tests/test_layout.py @@ -6,13 +6,13 @@ class TestLayoutComponents: def test_html(self): html = HTML("{% if saved %}Data saved{% endif %}").render_from_parent_context( - {"saved": True} + {"saved": True}, ) assert "Data saved" in html # step_field and step0 not defined html = HTML( - '' + '', ).render_from_parent_context() assert_select(html, "input") diff --git a/tests/test_preview.py b/tests/test_preview.py index 150ecd2..2f1f97e 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -35,7 +35,7 @@ def test_previews(self, client): reverse( "django_viewcomponent:previews", kwargs={"preview_name": "simple_example_component"}, - ) + ), ) assert response.status_code == 200 @@ -49,7 +49,7 @@ def test_preview(self, client): "preview_name": "simple_example_component", "example_name": "with_title", }, - ) + ), ) assert response.status_code == 200 diff --git a/tests/test_render_field.py b/tests/test_render_field.py index 76b38c8..e9c0117 100644 --- a/tests/test_render_field.py +++ b/tests/test_render_field.py @@ -1,5 +1,6 @@ import pytest from django.template import Context, Template +from django.utils.safestring import mark_safe from django_viewcomponent import component from django_viewcomponent.fields import RendersManyField, RendersOneField @@ -10,9 +11,7 @@ @pytest.mark.django_db class TestRenderFieldComponentParameterString: """ - test setting component using component string - - RendersOneField(required=True, component="header") + component parameter is a component name string """ class HeaderComponent(component.Component): @@ -73,7 +72,7 @@ def test_field_component_parameter(self): {% call component.posts post=post %}{% endcall %} {% endfor %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({"qs": qs})) expected = """ @@ -136,7 +135,7 @@ def __init__(self, post, **kwargs): @pytest.mark.django_db class TestRenderFieldComponentParameterClass: """ - test setting component using component Class + component parameter is a Component class """ @pytest.fixture(autouse=True) @@ -162,7 +161,7 @@ def test_field_component_parameter(self): {% call component.posts post=post %}{% endcall %} {% endfor %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({"qs": qs})) expected = """ @@ -189,14 +188,45 @@ def test_field_component_parameter(self): @pytest.mark.django_db -class TestRenderFieldComponentParameterLambda: +class TestRenderFieldComponentParameterLambdaReturnString: """ - test setting component using component Class + component parameter is a lambda that returns a string """ + class HeaderComponent(component.Component): + def __init__(self, classes, **kwargs): + self.classes = classes + + template = """ +

+ {{ self.content }} +

+ """ + + class BlogComponent(component.Component): + header = RendersOneField(required=True, component="header") + posts = RendersManyField( + required=True, + component=lambda post, **kwargs: mark_safe( + f""" +

{post.title}

+
{post.description}
+ """, + ), + ) + + template = """ + {% load viewcomponent_tags %} + {{ self.header.value }} + {% for post in self.posts.value %} + {{ post }} + {% endfor %} + """ + @pytest.fixture(autouse=True) def register_component(self): - component.registry.register("blog", BlogComponent) + component.registry.register("blog", self.BlogComponent) + component.registry.register("header", self.HeaderComponent) def test_field_component_parameter(self): for i in range(5): @@ -217,7 +247,100 @@ def test_field_component_parameter(self): {% call component.posts post=post %}{% endcall %} {% endfor %} {% endcomponent %} + """, + ) + rendered = template.render(Context({"qs": qs})) + expected = """ +

+ My Site +

+ +

test 0

+
test 0
+ +

test 1

+
test 1
+ +

test 2

+
test 2
+ +

test 3

+
test 3
+ +

test 4

+
test 4
+ """ + assert_dom_equal(expected, rendered) + + +class PostComponent(component.Component): + def __init__(self, post, **kwargs): + self.post = post + + template = """ + {% load viewcomponent_tags %} + +

{{ self.post.title }}

+
{{ self.post.description }}
+ """ + + +@pytest.mark.django_db +class TestRenderFieldComponentParameterLambdaReturnInstance: + """ + component parameter is a lambda that returns a component instance + """ + + class HeaderComponent(component.Component): + def __init__(self, classes, **kwargs): + self.classes = classes + + template = """ +

+ {{ self.content }} +

+ """ + + class BlogComponent(component.Component): + header = RendersOneField(required=True, component="header") + posts = RendersManyField( + required=True, + component=lambda post: PostComponent(post=post), + ) + + template = """ + {% load viewcomponent_tags %} + {{ self.header.value }} + {% for post in self.posts.value %} + {{ post }} + {% endfor %} + """ + + @pytest.fixture(autouse=True) + def register_component(self): + component.registry.register("blog", self.BlogComponent) + component.registry.register("header", self.HeaderComponent) + + def test_field_component_parameter(self): + for i in range(5): + title = f"test {i}" + description = f"test {i}" + Post.objects.create(title=title, description=description) + + qs = Post.objects.all() + + template = Template( """ + {% load viewcomponent_tags %} + {% component 'blog' as component %} + {% call component.header classes='text-lg' %} + My Site + {% endcall %} + {% for post in qs %} + {% call component.posts post=post %}{% endcall %} + {% endfor %} + {% endcomponent %} + """, ) rendered = template.render(Context({"qs": qs})) expected = """ diff --git a/tests/test_tags.py b/tests/test_tags.py index 54143e2..6d816a9 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -92,7 +92,8 @@ def test_single_component(self): template = Template(simple_tag_template) assert_dom_equal( - "Variable: variable", template.render(Context()) + "Variable: variable", + template.render(Context()), ) def test_call_with_invalid_name(self): @@ -210,7 +211,7 @@ def test_slotted_template_basic(self): {% component "test2" variable="variable" %}{% endcomponent %} {% endcall %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({})) @@ -243,7 +244,7 @@ def test_slotted_template_with_includes(self): {% component "test2" variable="variable" %}{% endcomponent %} {% endcall %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({})) @@ -274,7 +275,7 @@ def test_slotted_template_with_context_var(self): {% endcall %} {% endcomponent %} {% endwith %} - """ + """, ) rendered = template.render(Context({"my_second_variable": "test321"})) @@ -291,13 +292,14 @@ def test_slotted_template_with_context_var(self): def test_slotted_template_that_uses_missing_variable(self): component.registry.register( - name="test", component=SlottedComponentWithMissingVariable + name="test", + component=SlottedComponentWithMissingVariable, ) template = Template( """ {% load viewcomponent_tags %} {% component 'test' %}{% endcomponent %} - """ + """, ) rendered = template.render(Context({})) @@ -316,7 +318,7 @@ def test_slotted_template_no_slots_filled(self): component.registry.register(name="test", component=SlottedComponent) template = Template( - '{% load viewcomponent_tags %}{% component "test" %}{% endcomponent %}' + '{% load viewcomponent_tags %}{% component "test" %}{% endcomponent %}', ) rendered = template.render(Context({})) @@ -337,7 +339,7 @@ def test_slotted_template_without_slots(self): """ {% load viewcomponent_tags %} {% component "test" %}{% endcomponent %} - """ + """, ) rendered = template.render(Context({})) @@ -349,7 +351,7 @@ def test_slotted_template_without_slots_and_single_quotes(self): """ {% load viewcomponent_tags %} {% component 'test' %}{% endcomponent %} - """ + """, ) rendered = template.render(Context({})) @@ -367,7 +369,7 @@ class Component(component.Component): {% load viewcomponent_tags %} {% component 'test' %} {% endcomponent %} - """ + """, ) with pytest.raises(ValueError): template.render(Context({})) @@ -391,7 +393,7 @@ class Component(component.Component): {% call component_2.footer %}{% endcall %} {% endcomponent %} {% endcomponent %} - """ + """, ) expected = """
@@ -419,17 +421,17 @@ def make_template(self, first_component_slot="", second_component_slot=""): + "{% endcomponent %}" "{% component 'second_component' variable='xyz' as component_2 %}" + second_component_slot - + "{% endcomponent %}" + + "{% endcomponent %}", ) def expected_result(self, first_component_slot="", second_component_slot=""): return ( "
{}
".format( - first_component_slot or "Default header" + first_component_slot or "Default header", ) + "
Default main
Default footer
" + "
{}
".format( - second_component_slot or "Default header" + second_component_slot or "Default header", ) + "
Default main
Default footer
" ) @@ -448,11 +450,13 @@ def test_both_components_render_correctly_with_slots(self): second_slot_content = "
Slot #2
" first_slot = self.wrap_with_slot_tags("component_1.header", first_slot_content) second_slot = self.wrap_with_slot_tags( - "component_2.header", second_slot_content + "component_2.header", + second_slot_content, ) rendered = self.make_template(first_slot, second_slot).render(Context({})) assert_dom_equal( - self.expected_result(first_slot_content, second_slot_content), rendered + self.expected_result(first_slot_content, second_slot_content), + rendered, ) def test_both_components_render_correctly_when_only_first_has_slots(self): @@ -466,7 +470,8 @@ def test_both_components_render_correctly_when_only_second_has_slots(self): self.register_components() second_slot_content = "
Slot #2
" second_slot = self.wrap_with_slot_tags( - "component_2.header", second_slot_content + "component_2.header", + second_slot_content, ) rendered = self.make_template("", second_slot).render(Context({})) assert_dom_equal(self.expected_result("", second_slot_content), rendered) @@ -486,7 +491,7 @@ def test_default_slot_contents_render_correctly(self): """ {% load viewcomponent_tags %} {% component 'test' %}{% endcomponent %} - """ + """, ) rendered = template.render(Context({})) assert_dom_equal(rendered, '
Default
') @@ -498,7 +503,7 @@ def test_inner_slot_overriden(self): """ {% load viewcomponent_tags %} {% component 'test' as component %}{% call component.inner %}Override{% endcall %}{% endcomponent %} - """ + """, ) rendered = template.render(Context({})) assert_dom_equal(rendered, '
Override
') @@ -510,7 +515,7 @@ def test_outer_slot_overriden(self): """ {% load viewcomponent_tags %} {% component 'test' as component %}{% call component.outer %}

Override

{% endcall %}{% endcomponent %} - """ + """, ) rendered = template.render(Context({})) assert_dom_equal(rendered, "

Override

") @@ -525,7 +530,7 @@ def test_both_overriden_and_inner_removed(self): {% call component.outer %}

Override

{% endcall %} {% call component.inner %}

Will not appear

{% endcall %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({})) assert_dom_equal(rendered, "

Override

") @@ -551,7 +556,7 @@ def test_no_content_if_branches_are_false(self): {% call component.slot_a %}Override A{% endcall %} {% call component.slot_b %}Override B{% endcall %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({})) assert_dom_equal(rendered, "") @@ -564,7 +569,7 @@ def test_default_content_if_no_slots(self): {% load viewcomponent_tags %} {% component 'test' branch='a' %}{% endcomponent %} {% component 'test' branch='b' %}{% endcomponent %} - """ + """, ) rendered = template.render(Context({})) assert_dom_equal(rendered, '

Default A

Default B

') @@ -581,7 +586,7 @@ def test_one_slot_overridden(self): {% component 'test' branch='b' as component_2 %} {% call component_2.slot_b %}Override B{% endcall %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({})) assert_dom_equal(rendered, '

Default A

Override B

') @@ -600,7 +605,7 @@ def test_both_slots_overridden(self): {% call component_2.slot_a %}Override A{% endcall %} {% call component_2.slot_b %}Override B{% endcall %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({})) assert_dom_equal(rendered, '

Override A

Override B

') @@ -618,7 +623,7 @@ def test_variable(self): {% component "test" %} {{ anything }} {% endcomponent %} - """ + """, ) def test_text(self): @@ -628,7 +633,7 @@ def test_text(self): {% component "test" %} Text {% endcomponent %} - """ + """, ) def test_block_outside_call(self): @@ -640,7 +645,7 @@ def test_block_outside_call(self): {% call component.header %}{% endcall %} {% endif %} {% endcomponent %} - """ + """, ) def test_unclosed_component_is_error(self): @@ -650,7 +655,7 @@ def test_unclosed_component_is_error(self): {% load viewcomponent_tags %} {% component "test" %} {% call "header" %}{% endcall %} - """ + """, ) def test_fill_with_no_component_is_error(self): @@ -659,7 +664,7 @@ def test_fill_with_no_component_is_error(self): """ {% load viewcomponent_tags %} {% call component.header %}contents{% endcall %} - """ + """, ).render(Context({})) @@ -678,7 +683,7 @@ def test_component_nesting_component_without_fill(self): Hello, User X {% endcall %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({"items": [1, 2, 3]})) expected = """ @@ -718,7 +723,8 @@ class ComponentWithComplexConditionalSlots(component.Component): @pytest.fixture(autouse=True) def register_component(self): component.registry.register( - "conditional_slots", self.ComponentWithConditionalSlots + "conditional_slots", + self.ComponentWithConditionalSlots, ) component.registry.register( "complex_conditional_slots", @@ -867,7 +873,7 @@ def test_missing_required_slot_raises_error(self): {% load viewcomponent_tags %} {% component 'tabs' %} {% endcomponent %} - """ + """, ) with pytest.raises(ValueError): template.render(Context({})) @@ -885,7 +891,7 @@ def test_collection_basic(self): {% call component.panels %}Panel 2{% endcall %} {% call component.panels %}Panel 3{% endcall %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({})) assert_dom_equal( @@ -932,7 +938,7 @@ def test_collection_variable(self): {% call component.panels %}{{ panel }} 3{% endcall %} {% endcomponent %} {% endwith %} - """ + """, ) rendered = template.render(Context({})) assert_dom_equal( @@ -979,7 +985,7 @@ def test_component_namespace(self): {% load viewcomponent_tags %} {% component 'testapp.example' %} {% endcomponent %} - """ + """, ) rendered = template.render(Context({})) expected = """ @@ -993,7 +999,7 @@ def test_component_namespace_with_parameters(self): {% load viewcomponent_tags %} {% component 'testapp.example' name="MichaelYin"%} {% endcomponent %} - """ + """, ) rendered = template.render(Context({})) expected = """ diff --git a/tests/testapp/layout.py b/tests/testapp/layout.py index 04026fe..1b6229a 100644 --- a/tests/testapp/layout.py +++ b/tests/testapp/layout.py @@ -21,7 +21,7 @@ def get_context_data(self): [ child_component.render_from_parent_context(context) for child_component in self.fields - ] + ], ) return context diff --git a/tests/utils.py b/tests/utils.py index 01c4be5..97b8d8b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -44,7 +44,7 @@ def assert_select(content, selector, equality=True, message=None, **tests): return doc.assert_select(selector, equality=equality, message=message, **tests) -class Page(object): +class Page: """ https://github.com/aroberts/assert-select @@ -54,7 +54,8 @@ class Page(object): def __init__(self, content=None, filename=None): if filename: - content = open(filename) + with open(filename) as f: + content = f.read() self.doc = BeautifulSoup(content, "html.parser") def __repr__(self): @@ -67,7 +68,7 @@ def css_select(self, selector): """ return self.doc.select(selector) - def assert_select(self, selector, equality=True, message=None, **tests): + def assert_select(self, selector, equality=True, message=None, **tests): # noqa """ Asserts that a css selector captures data from this Page, and that that data passes the test presented by the equality specifier. @@ -96,16 +97,16 @@ def assert_select(self, selector, equality=True, message=None, **tests): # set up tests equality_type = type(equality) - if equality_type == bool: + if equality_type == bool: # noqa if equality: tests["minimum"] = 1 else: tests["count"] = 0 - elif equality_type == int: + elif equality_type == int: # noqa tests["count"] = equality - elif equality_type in (str, re_type): + elif equality_type in (str, re_type): # noqa tests["text"] = equality - elif equality_type == list: + elif equality_type == list: # noqa tests["maximim"] = max(equality) tests["minimum"] = min(equality) else: @@ -119,7 +120,7 @@ def assert_select(self, selector, equality=True, message=None, **tests): elements = self.css_select(selector) if "text" in tests: match_with = tests["text"] - if type(match_with) == str: + if type(match_with) == str: # noqa filtered_elements = [e for e in elements if match_with in e.string] else: filtered_elements = [e for e in elements if match_with.match(e.string)]