diff --git a/tap_postgres/client.py b/tap_postgres/client.py index bba728e..2babaf5 100644 --- a/tap_postgres/client.py +++ b/tap_postgres/client.py @@ -429,7 +429,7 @@ def get_records(self, context: Context | None) -> Iterable[dict[str, Any]]: while True: message = logical_replication_cursor.read_message() if message: - row = self.consume(message) + row = self.consume(message, logical_replication_cursor) if row: yield row else: @@ -456,7 +456,7 @@ def get_records(self, context: Context | None) -> Iterable[dict[str, Any]]: logical_replication_cursor.close() logical_replication_connection.close() - def consume(self, message) -> dict | None: + def consume(self, message, cursor) -> dict | None: """Ingest WAL message.""" try: message_payload = json.loads(message.payload) @@ -476,12 +476,12 @@ def consume(self, message) -> dict | None: if message_payload["action"] in upsert_actions: for column in message_payload["columns"]: - row.update({column["name"]: column["value"]}) + row.update({column["name"]: self._parse_column_value(column, cursor)}) row.update({"_sdc_deleted_at": None}) row.update({"_sdc_lsn": message.data_start}) elif message_payload["action"] in delete_actions: for column in message_payload["identity"]: - row.update({column["name"]: column["value"]}) + row.update({column["name"]: self._parse_column_value(column, cursor)}) row.update( { "_sdc_deleted_at": datetime.datetime.utcnow().strftime( @@ -517,6 +517,15 @@ def consume(self, message) -> dict | None: return row + def _parse_column_value(self, column, cursor): + # When using log based replication, the wal2json output for columns of + # array types returns a string encoded in sql format, e.g. '{a,b}' + # https://github.com/eulerto/wal2json/issues/221#issuecomment-1025143441 + if column["type"] == "text[]": + return psycopg2.extensions.STRINGARRAY(column["value"], cursor) + + return column["value"] + def logical_replication_connection(self): """A logical replication connection to the database. diff --git a/tests/test_log_based.py b/tests/test_log_based.py index 51075a7..5656b08 100644 --- a/tests/test_log_based.py +++ b/tests/test_log_based.py @@ -1,7 +1,7 @@ import json import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import BIGINT, TEXT +from sqlalchemy.dialects.postgresql import ARRAY, BIGINT, TEXT from tap_postgres.tap import TapPostgres from tests.test_core import PostgresTestRunner @@ -57,3 +57,53 @@ def test_null_append(): tap_class=TapPostgres, config=LOG_BASED_CONFIG, catalog=tap_catalog ) test_runner.sync_all() + + +def test_string_array_column(): + """LOG_BASED syncs failed with array columns. + + This test checks that even when a catalog contains properties with types represented + as arrays (ex: "text[]") LOG_BASED replication can properly decode their value. + """ + table_name = "test_array_column" + engine = sa.create_engine( + "postgresql://postgres:postgres@localhost:5434/postgres", future=True + ) + + metadata_obj = sa.MetaData() + table = sa.Table( + table_name, + metadata_obj, + sa.Column("id", BIGINT, primary_key=True), + sa.Column("data", ARRAY(TEXT), nullable=True), + ) + with engine.begin() as conn: + table.drop(conn, checkfirst=True) + metadata_obj.create_all(conn) + insert = table.insert().values(id=123, data=["1", "2"]) + conn.execute(insert) + insert = table.insert().values(id=321, data=['This is a "test"', "2"]) + conn.execute(insert) + + tap = TapPostgres(config=LOG_BASED_CONFIG) + tap_catalog = json.loads(tap.catalog_json_text) + altered_table_name = f"public-{table_name}" + + for stream in tap_catalog["streams"]: + if stream.get("stream") and altered_table_name not in stream["stream"]: + for metadata in stream["metadata"]: + metadata["metadata"]["selected"] = False + else: + stream["replication_method"] = "LOG_BASED" + stream["replication_key"] = "_sdc_lsn" + for metadata in stream["metadata"]: + metadata["metadata"]["selected"] = True + if metadata["breadcrumb"] == []: + metadata["metadata"]["replication-method"] = "LOG_BASED" + + test_runner = PostgresTestRunner( + tap_class=TapPostgres, config=LOG_BASED_CONFIG, catalog=tap_catalog + ) + test_runner.sync_all() + assert test_runner.record_messages[0]["record"]["data"] == ["1", "2"] + assert test_runner.record_messages[1]["record"]["data"] == ['This is a "test"', "2"]