@@ -106,62 +106,94 @@ class DownstreamMemoryInfo:
106106 object_store_memory : float
107107
108108
109- class RefBundleDeque (deque ):
110- """Thread-safe wrapper around collections.deque that stores current stats."""
109+ class OpBufferQueue :
110+ """A FIFO queue to buffer RefBundles between upstream and downstream operators.
111+ This class is thread-safe.
112+ """
111113
112114 def __init__ (self ):
113115 self ._memory_usage = 0
114116 self ._num_blocks = 0
117+ self ._queue = deque ()
118+ self ._num_per_split = defaultdict (int )
115119 self ._lock = threading .Lock ()
116120 super ().__init__ ()
117121
118122 @property
119123 def memory_usage (self ) -> int :
124+ """The total memory usage of the queue in bytes."""
120125 with self ._lock :
121126 return self ._memory_usage
122127
123128 @property
124129 def num_blocks (self ) -> int :
130+ """The total number of blocks in the queue."""
125131 with self ._lock :
126132 return self ._num_blocks
127133
128- def append (self , ref : RefBundle ):
129- with self ._lock :
130- self ._memory_usage += ref .size_bytes ()
131- self ._num_blocks += len (ref .blocks )
132- super ().append (ref )
134+ def __len__ (self ):
135+ return len (self ._queue )
133136
134- def appendleft (self , ref : RefBundle ):
135- with self ._lock :
136- self ._memory_usage += ref .size_bytes ()
137- self ._num_blocks += len (ref .blocks )
138- super ().appendleft (ref )
137+ def has_next (self , output_split_idx : Optional [int ] = None ) -> bool :
138+ """Whether next RefBundle is available.
139139
140- def pop (self ) -> RefBundle :
141- ref = super ().pop ()
142- with self ._lock :
143- self ._memory_usage -= ref .size_bytes ()
144- self ._num_blocks -= len (ref .blocks )
145- return ref
140+ Args:
141+ output_split_idx: If specified, only check ref bundles with the
142+ given output split.
143+ """
144+ if output_split_idx is None :
145+ return len (self ._queue ) > 0
146+ else :
147+ with self ._lock :
148+ return self ._num_per_split [output_split_idx ] > 0
146149
147- def popleft (self ) -> RefBundle :
148- ref = super ().popleft ()
150+ def append (self , ref : RefBundle ):
151+ """Append a RefBundle to the queue."""
152+ self ._queue .append (ref )
149153 with self ._lock :
150- self ._memory_usage -= ref .size_bytes ()
151- self ._num_blocks -= len (ref .blocks )
152- return ref
153-
154- def remove (self , ref : RefBundle ):
155- super ().remove (ref )
154+ self ._memory_usage += ref .size_bytes ()
155+ self ._num_blocks += len (ref .blocks )
156+ if ref .output_split_idx is not None :
157+ self ._num_per_split [ref .output_split_idx ] += 1
158+
159+ def pop (self , output_split_idx : Optional [int ] = None ) -> Optional [RefBundle ]:
160+ """Pop a RefBundle from the queue.
161+ Args:
162+ output_split_idx: If specified, only pop a RefBundle
163+ with the given output split.
164+ Returns:
165+ A RefBundle if available, otherwise None.
166+ """
167+ ret = None
168+ if output_split_idx is None :
169+ try :
170+ ret = self ._queue .popleft ()
171+ except IndexError :
172+ pass
173+ else :
174+ # TODO(hchen): Index the queue by output_split_idx to
175+ # avoid linear scan.
176+ for i in range (len (self ._queue )):
177+ ref = self ._queue [i ]
178+ if ref .output_split_idx == output_split_idx :
179+ ret = ref
180+ del self ._queue [i ]
181+ break
182+ if ret is None :
183+ return None
156184 with self ._lock :
157- self ._memory_usage -= ref .size_bytes ()
158- self ._num_blocks -= len (ref .blocks )
185+ self ._memory_usage -= ret .size_bytes ()
186+ self ._num_blocks -= len (ret .blocks )
187+ if ret .output_split_idx is not None :
188+ self ._num_per_split [ret .output_split_idx ] -= 1
189+ return ret
159190
160191 def clear (self ):
161- super ().clear ()
162192 with self ._lock :
193+ self ._queue .clear ()
163194 self ._memory_usage = 0
164195 self ._num_blocks = 0
196+ self ._num_per_split .clear ()
165197
166198
167199class OpState :
@@ -174,17 +206,17 @@ class OpState:
174206 operator queues to be shared across threads.
175207 """
176208
177- def __init__ (self , op : PhysicalOperator , inqueues : List [RefBundleDeque ]):
209+ def __init__ (self , op : PhysicalOperator , inqueues : List [OpBufferQueue ]):
178210 # Each inqueue is connected to another operator's outqueue.
179211 assert len (inqueues ) == len (op .input_dependencies ), (op , inqueues )
180- self .inqueues : List [RefBundleDeque ] = inqueues
212+ self .inqueues : List [OpBufferQueue ] = inqueues
181213 # The outqueue is connected to another operator's inqueue (they physically
182214 # share the same Python list reference).
183215 #
184216 # Note: this queue is also accessed concurrently from the consumer thread.
185217 # (in addition to the streaming executor thread). Hence, it must be a
186218 # thread-safe type such as `deque`.
187- self .outqueue : RefBundleDeque = RefBundleDeque ()
219+ self .outqueue : OpBufferQueue = OpBufferQueue ()
188220 self .op = op
189221 self .progress_bar = None
190222 self .num_completed_tasks = 0
@@ -266,8 +298,9 @@ def summary_str(self) -> str:
266298 def dispatch_next_task (self ) -> None :
267299 """Move a bundle from the operator inqueue to the operator itself."""
268300 for i , inqueue in enumerate (self .inqueues ):
269- if inqueue :
270- self .op .add_input (inqueue .popleft (), input_index = i )
301+ ref = inqueue .pop ()
302+ if ref is not None :
303+ self .op .add_input (ref , input_index = i )
271304 return
272305 assert False , "Nothing to dispatch"
273306
@@ -285,24 +318,11 @@ def get_output_blocking(self, output_split_idx: Optional[int]) -> RefBundle:
285318 # Check if StreamingExecutor has caught an exception or is done execution.
286319 if self ._exception is not None :
287320 raise self ._exception
288- elif self ._finished and len ( self .outqueue ) == 0 :
321+ elif self ._finished and not self .outqueue . has_next ( output_split_idx ) :
289322 raise StopIteration ()
290- try :
291- # Non-split output case.
292- if output_split_idx is None :
293- return self .outqueue .popleft ()
294-
295- # Scan the queue and look for outputs tagged for the given index.
296- for i in range (len (self .outqueue )):
297- bundle = self .outqueue [i ]
298- if bundle .output_split_idx == output_split_idx :
299- self .outqueue .remove (bundle )
300- return bundle
301-
302- # Didn't find any outputs matching this index, repeat the loop until
303- # we find one or hit a None.
304- except IndexError :
305- pass
323+ ref = self .outqueue .pop (output_split_idx )
324+ if ref is not None :
325+ return ref
306326 time .sleep (0.01 )
307327
308328 def inqueue_memory_usage (self ) -> int :
0 commit comments