Skip to content

Commit bdab20c

Browse files
committed
Make Metadata class work correctly if entered multiple times.
PiperOrigin-RevId: 425456552
1 parent 0784a1f commit bdab20c

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

tfx/orchestration/kubeflow/container_entrypoint_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def testOverrideRegisterExecution(self):
129129
(python_driver_operator.PythonDriverOperator, 'run_driver',
130130
driver_output_pb2.DriverOutput()),
131131
(metadata.Metadata, '__init__', None),
132+
(metadata.Metadata, '__exit__', None),
132133
(launcher.Launcher, '_publish_successful_execution', None),
133134
(launcher.Launcher, '_clean_up_stateless_execution_info', None),
134135
(launcher.Launcher, '_clean_up_stateful_execution_info', None),

tfx/orchestration/metadata.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,15 @@ def mysql_metadata_connection_config(
125125
# TODO(ruoyu): Figure out the story mutable UDFs. We should not reuse previous
126126
# run when having different UDFs.
127127
class Metadata:
128-
"""Helper class to handle metadata I/O."""
128+
"""Helper class to handle metadata I/O.
129+
130+
Not thread-safe without external synchronisation.
131+
"""
129132

130133
def __init__(self, connection_config: ConnectionConfigType) -> None:
131134
self._connection_config = connection_config
132135
self._store = None
136+
self._users = 0
133137

134138
def __enter__(self) -> 'Metadata':
135139
# TODO(ruoyu): Establishing a connection pool instead of newing
@@ -147,6 +151,7 @@ def __enter__(self) -> 'Metadata':
147151
time.sleep(random.random())
148152
continue
149153
else:
154+
self._users += 1
150155
return self
151156

152157
raise RuntimeError(
@@ -156,7 +161,9 @@ def __enter__(self) -> 'Metadata':
156161
def __exit__(self, exc_type: Optional[Type[Exception]],
157162
exc_value: Optional[Exception],
158163
exc_tb: Optional[types.TracebackType]) -> None:
159-
self._store = None
164+
self._users -= 1
165+
if self._users == 0:
166+
self._store = None
160167

161168
@property
162169
def store(self) -> mlmd.MetadataStore:

0 commit comments

Comments
 (0)