From c25ef7ea6115d26f2e022c5f647282e43cde4f37 Mon Sep 17 00:00:00 2001 From: JasonCheung Date: Thu, 12 Mar 2020 17:17:15 +0800 Subject: [PATCH] fix sql session error in auto update related index --- flask_msearch/backends.py | 12 +++++++++++- test/test_whoosh.py | 22 +++++++++++++++++++--- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/flask_msearch/backends.py b/flask_msearch/backends.py index 245955b..5a39f73 100755 --- a/flask_msearch/backends.py +++ b/flask_msearch/backends.py @@ -13,7 +13,7 @@ import logging import sys -from flask_sqlalchemy import models_committed +from flask_sqlalchemy import models_committed, before_models_committed from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect from werkzeug import import_string @@ -111,8 +111,18 @@ def _signal_connect(self, app): self._signal = import_string(signal) else: self._signal = signal + before_models_committed.connect(self._before_models_committed) models_committed.connect(self.index_signal) + def _before_models_committed(self, sender, changes): + self.db.session.flush() + for instance, operation in changes: + if hasattr(instance, '__searchable__'): + for field in getattr(instance, '__searchable__', []): + if '.' in field: + splits = field.split('.') + getattr(instance, splits[0]) + def index_signal(self, sender, changes): return self._signal(self, sender, changes) diff --git a/test/test_whoosh.py b/test/test_whoosh.py index 953baaa..5408596 100644 --- a/test/test_whoosh.py +++ b/test/test_whoosh.py @@ -133,6 +133,22 @@ def test_field_search(self): results = self.Post.query.msearch('tag', fields=['tag.name']).all() self.assertEqual(len(results), 2) + results = self.Post.query.msearch('changed', fields=['title']).all() + self.assertEqual(len(results), 0) + + post2 = self.Post.query.filter(self.Post.id == post2.id).first() + post2.title = 'changed title' + self.db.session.commit() + + results = self.Post.query.msearch('changed', fields=['title']).all() + self.assertEqual(len(results), 1) + + post3 = self.Post(title=title1, content=content1, tag_id=tag1.id) + post3.save(self.db) + + results = self.Post.query.msearch('tag', fields=['tag.name']).all() + self.assertEqual(len(results), 3) + class TestSearchHybridProp(TestMixin, SearchTestBase): def setUp(self): @@ -270,10 +286,10 @@ def save(self, db): if __name__ == '__main__': suite = unittest.TestLoader().loadTestsFromNames([ - 'test_whoosh.TestSearch', + # 'test_whoosh.TestSearch', # 'test_whoosh.TestPrimaryKey', 'test_whoosh.TestRelationSearch', - 'test_whoosh.TestSearchHybridProp', - 'test_whoosh.TestHybridPropTypeHint', + # 'test_whoosh.TestSearchHybridProp', + # 'test_whoosh.TestHybridPropTypeHint', ]) unittest.TextTestRunner(verbosity=1).run(suite)