diff --git a/pyproject.toml b/pyproject.toml index 283338a838..c91608b6ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [build-system] requires = [ + "protobuf<5", "grpcio-tools>=1.56.2,<2", "mypy-protobuf>=3.1", "pybindgen==0.22.0", diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py index 8f47fab077..042eee06ab 100644 --- a/sdk/python/feast/infra/registry/caching_registry.py +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -19,6 +19,7 @@ from feast.permissions.permission import Permission from feast.project import Project from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.saved_dataset import SavedDataset, ValidationReference from feast.stream_feature_view import StreamFeatureView from feast.utils import _utc_now @@ -28,13 +29,14 @@ class CachingRegistry(BaseRegistry): def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str): - self.cached_registry_proto = self.proto() - self.cached_registry_proto_created = _utc_now() + self.cache_mode = cache_mode + self.cached_registry_proto = RegistryProto() self._refresh_lock = Lock() self.cached_registry_proto_ttl = timedelta( seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0 ) - self.cache_mode = cache_mode + self.cached_registry_proto = self.proto() + self.cached_registry_proto_created = _utc_now() if cache_mode == "thread": self._start_thread_async_refresh(cache_ttl_seconds) atexit.register(self._exit_handler) @@ -429,20 +431,26 @@ def refresh(self, project: Optional[str] = None): def _refresh_cached_registry_if_necessary(self): if self.cache_mode == "sync": with self._refresh_lock: - expired = ( - self.cached_registry_proto is None - or self.cached_registry_proto_created is None - ) or ( - self.cached_registry_proto_ttl.total_seconds() - > 0 # 0 ttl means infinity - and ( - _utc_now() - > ( - self.cached_registry_proto_created - + self.cached_registry_proto_ttl + if self.cached_registry_proto == RegistryProto(): + # Avoids the need to refresh the registry when cache is not populated yet + # Specially during the __init__ phase + # proto() will populate the cache with project metadata if no objects are registered + expired = False + else: + expired = ( + self.cached_registry_proto is None + or self.cached_registry_proto_created is None + ) or ( + self.cached_registry_proto_ttl.total_seconds() + > 0 # 0 ttl means infinity + and ( + _utc_now() + > ( + self.cached_registry_proto_created + + self.cached_registry_proto_ttl + ) ) ) - ) if expired: logger.info("Registry cache expired, so refreshing") self.refresh() diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index d6a716e082..a6a2417c6e 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -251,6 +251,8 @@ def __init__( registry_config, SqlRegistryConfig ), "SqlRegistry needs a valid registry_config" + self.registry_config = registry_config + self.write_engine: Engine = create_engine( registry_config.path, **registry_config.sqlalchemy_config_kwargs ) @@ -281,7 +283,7 @@ def __init__( def _sync_feast_metadata_to_projects_table(self): feast_metadata_projects: set = [] projects_set: set = [] - with self.write_engine.begin() as conn: + with self.read_engine.begin() as conn: stmt = select(feast_metadata).where( feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value ) @@ -290,7 +292,7 @@ def _sync_feast_metadata_to_projects_table(self): feast_metadata_projects.append(row._mapping["project_id"]) if len(feast_metadata_projects) > 0: - with self.write_engine.begin() as conn: + with self.read_engine.begin() as conn: stmt = select(projects) rows = conn.execute(stmt).all() for row in rows: diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py index 5dc2509333..0bed89ca16 100644 --- a/sdk/python/tests/integration/registration/test_universal_registry.py +++ b/sdk/python/tests/integration/registration/test_universal_registry.py @@ -1767,3 +1767,64 @@ def test_apply_entity_success_with_purge_feast_metadata(test_registry): assert len(entities) == 0 test_registry.teardown() + + +@pytest.mark.integration +@pytest.mark.parametrize( + "test_registry", + sql_fixtures + async_sql_fixtures, +) +def test_apply_entity_to_sql_registry_and_reinitialize_sql_registry(test_registry): + entity = Entity( + name="driver_car_id", + description="Car driver id", + tags={"team": "matchmaking"}, + ) + + project = "project" + + # Register Entity + test_registry.apply_entity(entity, project) + assert_project(project, test_registry) + + entities = test_registry.list_entities(project, tags=entity.tags) + assert_project(project, test_registry) + + entity = entities[0] + assert ( + len(entities) == 1 + and entity.name == "driver_car_id" + and entity.description == "Car driver id" + and "team" in entity.tags + and entity.tags["team"] == "matchmaking" + ) + + entity = test_registry.get_entity("driver_car_id", project) + assert ( + entity.name == "driver_car_id" + and entity.description == "Car driver id" + and "team" in entity.tags + and entity.tags["team"] == "matchmaking" + ) + + # After the first apply, the created_timestamp should be the same as the last_update_timestamp. + assert entity.created_timestamp == entity.last_updated_timestamp + updated_test_registry = SqlRegistry(test_registry.registry_config, "project", None) + + # Update entity + updated_entity = Entity( + name="driver_car_id", + description="Car driver Id", + tags={"team": "matchmaking"}, + ) + updated_test_registry.apply_entity(updated_entity, project) + + updated_entity = updated_test_registry.get_entity("driver_car_id", project) + updated_test_registry.delete_entity("driver_car_id", project) + assert_project(project, updated_test_registry) + entities = updated_test_registry.list_entities(project) + assert_project(project, updated_test_registry) + assert len(entities) == 0 + + updated_test_registry.teardown() + test_registry.teardown()