Skip to content

Commit f1ef03d

Browse files
ForeverAngryFokko
authored and committed
Add ExpireSnapshots (apache#1880)
## Summary This PR closes issue apache#516 by implementing support for the `ExpireSnapshots` table metadata action. ## Rationale The `ExpireSnapshots` action is a core part of Iceberg’s table maintenance APIs. Adding support for this action in PyIceberg helps ensure feature parity with other language implementations (e.g., Java) and supports users who want to programmatically manage snapshot retention using PyIceberg’s public API. ## Testing - Unit tests have been added to cover the initial expected usage paths. - Additional feedback on edge cases, missing test scenarios, or corrections to the test setup logic is very welcome during the review process. ## User-facing changes - This change introduces a new public API: `ExpireSnapshots`, reachable through `Table.expire_snapshots()`. - No breaking changes or modifications to existing APIs were made. --- --------- Co-authored-by: Fokko Driesprong <[email protected]>
1 parent 2a5fb97 commit f1ef03d

File tree

3 files changed

+331
-1
lines changed

3 files changed

+331
-1
lines changed

pyiceberg/table/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@
115115
update_table_metadata,
116116
)
117117
from pyiceberg.table.update.schema import UpdateSchema
118-
from pyiceberg.table.update.snapshot import ManageSnapshots, RewriteManifestsResult, UpdateSnapshot, _FastAppendFiles
118+
from pyiceberg.table.update.snapshot import ExpireSnapshots, ManageSnapshots, RewriteManifestsResult, UpdateSnapshot, _FastAppendFiles
119119
from pyiceberg.table.update.spec import UpdateSpec
120120
from pyiceberg.table.update.statistics import UpdateStatistics
121121
from pyiceberg.transforms import IdentityTransform
@@ -1202,6 +1202,10 @@ def manage_snapshots(self) -> ManageSnapshots:
12021202
"""
12031203
return ManageSnapshots(transaction=Transaction(self, autocommit=True))
12041204

1205+
def expire_snapshots(self) -> ExpireSnapshots:
    """
    Shorthand to expire snapshots, selected by ID or by timestamp.

    Returns:
        An ExpireSnapshots builder bound to a new transaction on this table that is
        created with autocommit enabled.
    """
    autocommit_transaction = Transaction(self, autocommit=True)
    return ExpireSnapshots(transaction=autocommit_transaction)
1208+
12051209
def update_statistics(self) -> UpdateStatistics:
12061210
"""
12071211
Shorthand to run statistics management operations like add statistics and remove statistics.

pyiceberg/table/update/snapshot.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from pyiceberg.partitioning import (
5858
PartitionSpec,
5959
)
60+
from pyiceberg.table.refs import SnapshotRefType
6061
from pyiceberg.table.snapshots import (
6162
Operation,
6263
Snapshot,
@@ -68,6 +69,7 @@
6869
AddSnapshotUpdate,
6970
AssertRefSnapshotId,
7071
RemoveSnapshotRefUpdate,
72+
RemoveSnapshotsUpdate,
7173
SetSnapshotRefUpdate,
7274
TableRequirement,
7375
TableUpdate,
@@ -1018,3 +1020,103 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots:
10181020
This for method chaining
10191021
"""
10201022
return self._remove_ref_snapshot(ref_name=branch_name)
1023+
1024+
1025+
class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
    """
    Expire snapshots by ID or by age.

    Use table.expire_snapshots().<operation>().commit() to run a specific operation.
    Use table.expire_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
    Pending changes are applied on commit.

    Snapshots referenced by a branch or a tag are protected and are never expired.
    """

    # Declared here for type checkers only; the values are created per instance in
    # __init__. A class-level `set()` default would be shared (and mutated in place)
    # by every ExpireSnapshots instance, leaking staged IDs between builders.
    _snapshot_ids_to_expire: Set[int]
    _updates: Tuple[TableUpdate, ...]
    _requirements: Tuple[TableRequirement, ...]

    # NOTE(review): `Transaction` is only used as an annotation here; presumably this
    # module already imports it (e.g. under TYPE_CHECKING) - confirm.
    def __init__(self, transaction: Transaction) -> None:
        """
        Create an empty expiration builder.

        Args:
            transaction: The transaction that the staged updates are applied to.
        """
        super().__init__(transaction)
        self._snapshot_ids_to_expire = set()
        self._updates = ()
        self._requirements = ()

    def _commit(self) -> UpdatesAndRequirements:
        """
        Commit the staged updates and requirements.

        This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads).

        Returns:
            Tuple of updates and requirements to be committed,
            as required by the calling parent apply functions.
        """
        # Defensive re-check: drop any staged ID that is currently referenced by a branch or tag.
        self._snapshot_ids_to_expire -= self._get_protected_snapshot_ids()
        update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire)
        self._updates += (update,)
        return self._updates, self._requirements

    def _get_protected_snapshot_ids(self) -> Set[int]:
        """
        Get the IDs of protected snapshots.

        These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration.

        Returns:
            Set of protected snapshot IDs to exclude from expiration.
        """
        # Keep this a sequence: membership is then decided by `==`, so ref types that merely
        # compare equal to the enum members still match. A set would compare by hash first.
        protected_ref_types = (SnapshotRefType.TAG, SnapshotRefType.BRANCH)
        return {
            ref.snapshot_id
            for ref in self._transaction.table_metadata.refs.values()
            if ref.snapshot_ref_type in protected_ref_types
        }

    def expire_snapshot_by_id(self, snapshot_id: int) -> ExpireSnapshots:
        """
        Expire a snapshot by its ID.

        This will mark the snapshot for expiration.

        Args:
            snapshot_id (int): The ID of the snapshot to expire.

        Returns:
            This for method chaining.

        Raises:
            ValueError: If the snapshot does not exist, or if it is referenced by a branch or tag.
        """
        if self._transaction.table_metadata.snapshot_by_id(snapshot_id) is None:
            raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.")

        if snapshot_id in self._get_protected_snapshot_ids():
            raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.")

        self._snapshot_ids_to_expire.add(snapshot_id)

        return self

    def expire_snapshots_by_ids(self, snapshot_ids: List[int]) -> ExpireSnapshots:
        """
        Expire multiple snapshots by their IDs.

        This will mark the snapshots for expiration. Each ID is validated exactly like
        expire_snapshot_by_id, so the first missing or protected ID raises and the
        remaining IDs are not staged.

        Args:
            snapshot_ids (List[int]): List of snapshot IDs to expire.

        Returns:
            This for method chaining.

        Raises:
            ValueError: If any snapshot does not exist, or is referenced by a branch or tag.
        """
        for snapshot_id in snapshot_ids:
            self.expire_snapshot_by_id(snapshot_id)
        return self

    def expire_snapshots_older_than(self, timestamp_ms: int) -> ExpireSnapshots:
        """
        Expire all unprotected snapshots with a timestamp older than a given value.

        Protected snapshots are skipped silently rather than raising.

        Args:
            timestamp_ms (int): Only snapshots with timestamp_ms < this value will be expired.

        Returns:
            This for method chaining.
        """
        # Compute the protected set once instead of once per snapshot.
        protected_ids = self._get_protected_snapshot_ids()
        self._snapshot_ids_to_expire.update(
            snapshot.snapshot_id
            for snapshot in self._transaction.table_metadata.snapshots
            if snapshot.timestamp_ms < timestamp_ms and snapshot.snapshot_id not in protected_ids
        )
        return self

tests/table/test_expire_snapshots.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from unittest.mock import MagicMock
18+
from uuid import uuid4
19+
20+
import pytest
21+
22+
from pyiceberg.table import CommitTableResponse, Table
23+
24+
25+
def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None:
    """Expiring the snapshot a branch points at must raise, and nothing may be committed."""
    branch_head_id = 3051729675574597004
    tagged_id = 3055729675574597004

    # Swap in a mock catalog so any commit attempt can be observed.
    table_v2.catalog = MagicMock()

    # Stand-in refs: "main" (branch) points at branch_head_id, "tag1" (tag) at tagged_id.
    protecting_refs = {
        "main": MagicMock(snapshot_id=branch_head_id, snapshot_ref_type="branch"),
        "tag1": MagicMock(snapshot_id=tagged_id, snapshot_ref_type="tag"),
    }
    table_v2.metadata = table_v2.metadata.model_copy(update={"refs": protecting_refs})

    # Sanity-check the fixture before exercising the code under test.
    assert branch_head_id in [ref.snapshot_id for ref in table_v2.metadata.refs.values()]

    expected_error = f"Snapshot with ID {branch_head_id} is protected and cannot be expired."
    with pytest.raises(ValueError, match=expected_error):
        table_v2.expire_snapshots().expire_snapshot_by_id(branch_head_id).commit()

    table_v2.catalog.commit_table.assert_not_called()
49+
50+
51+
def test_cannot_expire_tagged_snapshot(table_v2: Table) -> None:
    """Expiring a snapshot that a tag points at must raise, and nothing may be committed."""
    tagged_id = 3051729675574597004
    branch_head_id = 3055729675574597004

    table_v2.catalog = MagicMock()

    # Stand-in refs: "tag1" (tag) points at tagged_id, "main" (branch) at branch_head_id.
    protecting_refs = {
        "tag1": MagicMock(snapshot_id=tagged_id, snapshot_ref_type="tag"),
        "main": MagicMock(snapshot_id=branch_head_id, snapshot_ref_type="branch"),
    }
    table_v2.metadata = table_v2.metadata.model_copy(update={"refs": protecting_refs})

    # Sanity-check the fixture before exercising the code under test.
    assert tagged_id in [ref.snapshot_id for ref in table_v2.metadata.refs.values()]

    expected_error = f"Snapshot with ID {tagged_id} is protected and cannot be expired."
    with pytest.raises(ValueError, match=expected_error):
        table_v2.expire_snapshots().expire_snapshot_by_id(tagged_id).commit()

    table_v2.catalog.commit_table.assert_not_called()
72+
73+
74+
def test_expire_unprotected_snapshot(table_v2: Table) -> None:
    """A snapshot that no branch or tag points at can be expired with a single commit."""
    expired_id = 3051729675574597004
    kept_id = 3055729675574597004

    # The canned catalog response is derived from the fixture metadata *before* the refs
    # are replaced below; it reports kept_id as the only remaining snapshot.
    post_commit_metadata = table_v2.metadata.model_copy(update={"snapshots": [kept_id]})
    commit_response = CommitTableResponse(
        metadata=post_commit_metadata,
        metadata_location="mock://metadata/location",
        uuid=uuid4(),
    )
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = commit_response

    # Point every ref at kept_id so that nothing protects expired_id.
    refs_on_kept_snapshot = {
        "main": MagicMock(snapshot_id=kept_id, snapshot_ref_type="branch"),
        "tag1": MagicMock(snapshot_id=kept_id, snapshot_ref_type="tag"),
    }
    table_v2.metadata = table_v2.metadata.model_copy(update={"refs": refs_on_kept_snapshot})

    # Sanity-check the fixture before exercising the code under test.
    assert expired_id not in [ref.snapshot_id for ref in table_v2.metadata.refs.values()]

    table_v2.expire_snapshots().expire_snapshot_by_id(expired_id).commit()

    table_v2.catalog.commit_table.assert_called_once()
    remaining_snapshots = table_v2.metadata.snapshots
    assert expired_id not in remaining_snapshots
    assert len(table_v2.metadata.snapshots) == 1
107+
108+
109+
def test_expire_nonexistent_snapshot_raises(table_v2: Table) -> None:
    """Asking to expire an ID that is not in the table metadata must raise without committing."""
    missing_id = 9999999999999999999

    table_v2.catalog = MagicMock()
    # Drop all refs so the failure can only come from the existence check.
    metadata_without_refs = table_v2.metadata.model_copy(update={"refs": {}})
    table_v2.metadata = metadata_without_refs

    expected_error = f"Snapshot with ID {missing_id} does not exist."
    with pytest.raises(ValueError, match=expected_error):
        table_v2.expire_snapshots().expire_snapshot_by_id(missing_id).commit()

    table_v2.catalog.commit_table.assert_not_called()
120+
121+
122+
def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
    """Age-based expiration must leave branch heads and tagged snapshots untouched."""
    from types import SimpleNamespace

    # Two snapshots, both old: one is a branch head, the other is tagged.
    branch_head_id = 3051729675574597004
    tagged_id = 3055729675574597004

    protecting_refs = {
        "main": MagicMock(snapshot_id=branch_head_id, snapshot_ref_type="branch"),
        "mytag": MagicMock(snapshot_id=tagged_id, snapshot_ref_type="tag"),
    }
    old_snapshots = [
        SimpleNamespace(snapshot_id=snapshot_id, timestamp_ms=1, parent_snapshot_id=None)
        for snapshot_id in (branch_head_id, tagged_id)
    ]
    table_v2.metadata = table_v2.metadata.model_copy(update={"refs": protecting_refs, "snapshots": old_snapshots})
    table_v2.catalog = MagicMock()

    # Cutoff far in the future, so both snapshots are candidates by age alone.
    cutoff_ms = 9999999999999

    # The mocked catalog echoes the current metadata back, i.e. it simulates "no change".
    unchanged_response = CommitTableResponse(
        metadata=table_v2.metadata,
        metadata_location="mock://metadata/location",
        uuid=uuid4(),
    )
    table_v2.catalog.commit_table.return_value = unchanged_response

    table_v2.expire_snapshots().expire_snapshots_older_than(cutoff_ms).commit()
    # Reflect the commit in the table metadata, as the other tests do.
    table_v2.metadata = unchanged_response.metadata

    # Both protected snapshots are still present.
    surviving_ids = {snapshot.snapshot_id for snapshot in table_v2.metadata.snapshots}
    assert branch_head_id in surviving_ids
    assert tagged_id in surviving_ids

    # commit_table was still called, but the remove-snapshots update it received is empty.
    positional_args, _ = table_v2.catalog.commit_table.call_args
    sent_updates = positional_args[2] if len(positional_args) > 2 else ()
    matching_updates = [u for u in sent_updates if getattr(u, "action", None) == "remove-snapshots"]
    remove_update = matching_updates[0] if matching_updates else None
    assert remove_update is not None
    assert remove_update.snapshot_ids == []
171+
172+
173+
def test_expire_snapshots_by_ids(table_v2: Table) -> None:
    """Test that multiple unprotected snapshots can be expired by IDs."""
    from types import SimpleNamespace

    EXPIRE_SNAPSHOT_1 = 3051729675574597004
    EXPIRE_SNAPSHOT_2 = 3051729675574597005
    KEEP_SNAPSHOT = 3055729675574597004

    # The canned catalog response is derived from the fixture metadata *before* it is
    # replaced below; it reports KEEP_SNAPSHOT as the only remaining snapshot.
    mock_response = CommitTableResponse(
        metadata=table_v2.metadata.model_copy(update={"snapshots": [KEEP_SNAPSHOT]}),
        metadata_location="mock://metadata/location",
        uuid=uuid4(),
    )
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = mock_response

    # Single setup step: every ref points at KEEP_SNAPSHOT (so nothing protects the
    # snapshots to be expired) and the metadata contains all three snapshots.
    table_v2.metadata = table_v2.metadata.model_copy(
        update={
            "refs": {
                "main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
                "tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
            },
            "snapshots": [
                SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_1, timestamp_ms=1, parent_snapshot_id=None),
                SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_2, timestamp_ms=1, parent_snapshot_id=None),
                SimpleNamespace(snapshot_id=KEEP_SNAPSHOT, timestamp_ms=2, parent_snapshot_id=None),
            ],
        }
    )

    # Assert fixture data
    assert all(ref.snapshot_id not in (EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2) for ref in table_v2.metadata.refs.values())

    # Expire the snapshots
    table_v2.expire_snapshots().expire_snapshots_by_ids([EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2]).commit()

    table_v2.catalog.commit_table.assert_called_once()

    # Verify what was actually sent to the catalog, not only the mocked response:
    # the remove-snapshots update must carry exactly the two requested IDs.
    args, _ = table_v2.catalog.commit_table.call_args
    updates = args[2] if len(args) > 2 else ()
    remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None)
    assert remove_update is not None
    assert set(remove_update.snapshot_ids) == {EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2}

    remaining_snapshots = table_v2.metadata.snapshots
    assert EXPIRE_SNAPSHOT_1 not in remaining_snapshots
    assert EXPIRE_SNAPSHOT_2 not in remaining_snapshots
    assert len(table_v2.metadata.snapshots) == 1

0 commit comments

Comments
 (0)