1212from multiprocessing import Process , Queue , SimpleQueue , cpu_count , get_context
1313from multiprocessing .context import SpawnContext , SpawnProcess
1414from multiprocessing .managers import SyncManager
15+ from queue import Empty
1516from threading import Thread
1617from time import sleep
1718from uuid import uuid4
@@ -72,13 +73,12 @@ class GiskardTask:
@dataclass(frozen=True)
class GiskardResult:
    """Immutable result of a :class:`GiskardTask` executed in a worker process.

    Exactly one of ``result`` / ``exception`` is expected to be populated;
    ``logs`` always carries the log output captured while the task ran.
    """

    # Id of the task this result belongs to (matches GiskardTask.id).
    id: str
    # Captured log output produced while running the task.
    logs: str
    # Return value of the task callable (None when the task raised).
    result: Any = None
    # Stringified exception info (None when the task succeeded).
    exception: Any = None
7980
8081def _process_worker (tasks_queue : SimpleQueue , tasks_results : SimpleQueue , running_process : Dict [str , str ]):
81- # TODO(Bazire): For now, worker stops because the queues will be closed. Not the cleanest...
8282 pid = os .getpid ()
8383
8484 while True :
@@ -93,7 +93,7 @@ def _process_worker(tasks_queue: SimpleQueue, tasks_results: SimpleQueue, runnin
9393 handler = logging .StreamHandler (f )
9494 logging .getLogger ().addHandler (handler )
9595 try :
96- LOGGER .debug ("Doing task %" , task .id )
96+ LOGGER .debug ("Doing task %s " , task .id )
9797 running_process [task .id ] = pid
9898 result = task .callable (* task .args , ** task .kwargs )
9999 to_return = GiskardResult (id = task .id , result = result , logs = f .getvalue ())
@@ -119,7 +119,8 @@ class PoolState(Enum):
119119
120120
121121class WorkerPoolExecutor (Executor ):
122- def __init__ (self , nb_workers : Optional [int ] = None ):
122+ def __init__ (self , nb_workers : Optional [int ] = None , name : Optional [str ] = None ):
123+ self ._prefix = f"{ name } _" if name is not None else "giskard_pool_"
123124 if nb_workers is None :
124125 nb_workers = cpu_count ()
125126 if nb_workers <= 0 :
@@ -135,33 +136,42 @@ def __init__(self, nb_workers: Optional[int] = None):
135136 # Mapping of the running tasks and worker pids
136137 self ._running_process : Dict [str , str ] = self ._manager .dict ()
137138 # Mapping of the running tasks and worker pids
138- self ._with_timeout_tasks : List [Tuple [str , float ]] = []
139- # Queue with tasks to run
140- self ._pending_tasks_queue : SimpleQueue [GiskardTask ] = self ._mp_context .SimpleQueue ()
139+ self ._with_timeout_tasks : List [TimeoutData ] = []
141140 # Queue with tasks to run
141+ self ._pending_tasks_queue : Queue [GiskardTask ] = self ._mp_context .Queue ()
142+ # Queue with tasks to be consumed asap
142143 # As in ProcessPool, add one more to avoid idling process
143- self ._running_tasks_queue : Queue [GiskardTask ] = self ._mp_context .Queue (maxsize = self ._nb_workers + 1 )
144+ self ._running_tasks_queue : Queue [Optional [ GiskardTask ] ] = self ._mp_context .Queue (maxsize = self ._nb_workers + 1 )
144145 # Queue with results to notify
145- self ._tasks_results : SimpleQueue [GiskardResult ] = self ._mp_context .SimpleQueue ()
146+ self ._tasks_results : Queue [GiskardResult ] = self ._mp_context .Queue ()
146147 # Mapping task_id with future
147148 self ._futures_mapping : Dict [str , Future ] = dict ()
149+ LOGGER .debug ("Starting threads for the WorkerPoolExecutor" )
148150
149151 self ._threads = [
150- Thread (name = "giskard_pool_" + target .__name__ , target = target , daemon = True , args = [self ], kwargs = None )
152+ Thread (name = f" { self . _prefix } { target .__name__ } " , target = target , daemon = True , args = [self ], kwargs = None )
151153 for target in [_killer_thread , _feeder_thread , _results_thread ]
152154 ]
153155 for t in self ._threads :
154156 t .start ()
157+ LOGGER .debug ("Threads started, spawning workers..." )
158+ LOGGER .debug ("Starting the pool with %s" , {self ._nb_workers })
155159
156160 # Startup workers
157161 for _ in range (self ._nb_workers ):
158162 self ._spawn_worker ()
163+ LOGGER .info ("WorkerPoolExecutor is started" )
164+
165+ def health_check (self ):
166+ if any ([not p .is_alive () for p in self ._processes .values ()]):
167+ LOGGER .warning ("At least one process died for an unknown reason, marking pool as broken" )
168+ self ._state = PoolState .BROKEN
159169
160170 def _spawn_worker (self ):
161171 # Daemon means process are linked to main one, and will be stopped if current process is stopped
162172 p = self ._mp_context .Process (
163173 target = _process_worker ,
164- name = "giskard_worker_process " ,
174+ name = f" { self . _prefix } _worker_process " ,
165175 args = (self ._running_tasks_queue , self ._tasks_results , self ._running_process ),
166176 daemon = True ,
167177 )
@@ -190,7 +200,7 @@ def schedule(
190200 self ._futures_mapping [task .id ] = res
191201 self ._pending_tasks_queue .put (task )
192202 if timeout is not None :
193- self ._with_timeout_tasks .append ((task .id , time .monotonic () + timeout ))
203+ self ._with_timeout_tasks .append (TimeoutData (task .id , time .monotonic () + timeout ))
194204 return res
195205
196206 def shutdown (self , wait = True , timeout : float = 5 ):
@@ -207,8 +217,14 @@ def shutdown(self, wait=True, timeout: float = 5):
207217 while not self ._running_tasks_queue .empty ():
208218 try :
209219 self ._running_tasks_queue .get_nowait ()
210- except BaseException :
211- pass
220+ except ValueError as e :
221+ # This happens if queues is closed
222+ LOGGER .warning ("Running task queue is already closed" )
223+ LOGGER .exception (e )
224+ except Empty as e :
225+ # May happen if a process consume an element
226+ LOGGER .warning ("Queue was empty, skipping" )
227+ LOGGER .exception (e )
212228 # Try to nicely stop the worker, by adding None into the running tasks
213229 for _ in range (self ._nb_workers ):
214230 self ._running_tasks_queue .put (None , timeout = 1 )
@@ -235,92 +251,97 @@ def shutdown(self, wait=True, timeout: float = 5):
235251 return exit_codes
236252
237253
def _safe_get(queue: Queue, executor: WorkerPoolExecutor, timeout: float = 1) -> Tuple[Any, bool]:
    """Pop one item from ``queue``, tolerating shutdown races.

    Returns a ``(item, should_stop)`` pair:
      * ``(item, False)``  — an element was read and the pool is still live;
      * ``(None, False)``  — the read timed out, caller should loop and retry;
      * ``(None, True)``   — the executor reached a final state (or the queue
        was closed during an orderly shutdown), caller should exit.

    Raises:
        ValueError: re-raised when the queue is closed while the executor is
            NOT in a final state; the pool is marked BROKEN first.
    """
    try:
        # Bug fix: honor the `timeout` parameter (it was hard-coded to 1).
        result = queue.get(timeout=timeout)
    except Empty:
        # Nothing available within the timeout; fall through so the state
        # check below can still request a stop.
        result = None
    except ValueError as e:
        # multiprocessing queues raise ValueError when closed.
        if executor._state not in FINAL_STATES:
            LOGGER.error("Queue is closed, and executor not in final state")
            executor._state = PoolState.BROKEN
            raise e
        return None, True
    if executor._state in FINAL_STATES:
        # Pool is shutting down: drop any element read and tell caller to stop.
        return None, True
    return result, False
269+
270+
def _results_thread(
    executor: WorkerPoolExecutor,
):
    """Consume worker results and resolve the corresponding futures.

    Runs until the executor reaches a final state. For every GiskardResult
    read from ``_tasks_results``, the matching future (if any, and not
    cancelled) gets its captured ``logs`` attached and is completed with
    either the task's result or a RuntimeError wrapping the exception text.
    """
    while executor._state not in FINAL_STATES:
        result, should_stop = _safe_get(executor._tasks_results, executor)
        if should_stop:
            return
        if result is None:
            # Timed-out read: nothing to do, poll again.
            continue

        # pop() both fetches and cleans up the mapping atomically.
        future = executor._futures_mapping.pop(result.id, None)
        if future is None or future.cancelled():
            # Future already gone or cancelled: discard the result.
            continue
        future.logs = result.logs
        if result.exception is None:
            future.set_result(result.result)
        else:
            # Exception crossed the process boundary as text; rewrap it.
            future.set_exception(RuntimeError(result.exception))
268292
def _feeder_thread(
    executor: WorkerPoolExecutor,
):
    """Move tasks from the pending queue into the bounded running queue.

    Runs until the executor reaches a final state. The running queue is
    bounded (workers + 1), so ``put`` provides back-pressure while keeping
    workers from idling. Cancelled or unknown tasks are dropped.
    """
    while executor._state not in FINAL_STATES:
        task, should_stop = _safe_get(executor._pending_tasks_queue, executor)
        if should_stop:
            return
        if task is None:
            # Timed-out read: nothing to do, poll again.
            continue

        future = executor._futures_mapping.get(task.id)
        if future is None:
            # No tracking future anymore: task was cleaned up, skip it.
            continue
        if future.set_running_or_notify_cancel():
            executor._running_tasks_queue.put(task)
        else:
            # Future has been cancelled already, nothing to do but clean up.
            # (default=None avoids KeyError if the results thread raced us)
            executor._futures_mapping.pop(task.id, None)
290313
291314def _killer_thread (
292315 executor : WorkerPoolExecutor ,
293316):
294- while True :
317+ while executor . _state not in FINAL_STATES :
295318 while len (executor ._with_timeout_tasks ) == 0 and executor ._state not in FINAL_STATES :
296319 # No need to be too active
297320 sleep (1 )
321+ executor .health_check ()
298322 if executor ._state in FINAL_STATES :
299323 return
300324
301- clean_up = []
325+ clean_up : List [ TimeoutData ] = []
302326 exception = None
303327 for timeout_data in executor ._with_timeout_tasks :
304- task_id , end_time = timeout_data
305- if task_id not in executor ._futures_mapping :
328+ if timeout_data .id not in executor ._futures_mapping :
306329 # Task is already completed, do not wait for it
307330 clean_up .append (timeout_data )
308- elif time .monotonic () > end_time :
331+ elif time .monotonic () > timeout_data . end_time :
309332 # Task has timed out, we should kill it
310333 try :
311- pid = executor ._running_process .get (task_id )
312- if pid is None :
313- # Task must have finished
314- continue
315- future = executor ._futures_mapping .get (task_id )
316- if not future .cancel ():
334+ future = executor ._futures_mapping .pop (timeout_data .id , None )
335+ if future is not None and not future .cancel ():
317336 LOGGER .warning ("Killing a timed out process" )
318337 future .set_exception (TimeoutError ("Task took too long" ))
319- p = executor ._processes [pid ]
320- del executor ._processes [pid ]
321- _stop_processes ([p ])
322- executor ._spawn_worker ()
338+ pid = executor ._running_process .pop (timeout_data .id , None )
339+ if pid is not None :
340+ p = executor ._processes .pop (pid )
341+ _stop_processes ([p ])
342+ executor ._spawn_worker ()
323343 except BaseException as e :
344+ # This is probably an OSError, but we want to be extra safe
324345 LOGGER .warning ("Unexpected error when killing a timed out process, pool is broken" )
325346 LOGGER .exception (e )
326347 exception = e
0 commit comments