Skip to content

Commit 9e4784a

Browse files
committed
Improve threads + take review feedback into account
1 parent 42e7a0c commit 9e4784a

4 files changed

Lines changed: 139 additions & 70 deletions

File tree

giskard/ml_worker/testing/registry/registry.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Dict, Optional
2+
13
import hashlib
24
import importlib.util
35
import inspect
@@ -6,16 +8,15 @@
68
import sys
79
import uuid
810
from pathlib import Path
9-
from typing import Optional, Dict
1011

1112
import cloudpickle
1213

1314
from giskard.core.core import SavableMeta
14-
from giskard.settings import expand_env_var, settings
15+
from giskard.settings import settings
1516

1617

1718
def find_plugin_location():
18-
return Path(expand_env_var(settings.home)) / "plugins"
19+
return settings.home_dir / "plugins"
1920

2021

2122
logger = logging.getLogger(__name__)

giskard/utils/__init__.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ def wrapper(*args, **kwargs):
1919
return wrapper
2020

2121

22-
class WorkerPool:
22+
NOT_STARTED = "Pool is not started"
23+
24+
25+
class SingletonWorkerPool:
2326
"Utility class to wrap a Process pool"
2427

2528
def __init__(self):
@@ -30,9 +33,7 @@ def start(self, max_workers: int = None):
3033
if self.pool is not None:
3134
return
3235
self.max_workers = max(max_workers, settings.min_workers) if max_workers is not None else os.cpu_count()
33-
LOGGER.info("Starting worker pool with %s workers...", self.max_workers)
3436
self.pool = WorkerPoolExecutor(nb_workers=self.max_workers)
35-
LOGGER.info("Pool is started")
3637

3738
def shutdown(self, wait=True):
3839
if self.pool is None:
@@ -41,21 +42,21 @@ def shutdown(self, wait=True):
4142

4243
def schedule(self, fn, args=None, kwargs=None, timeout=None) -> Future:
4344
if self.pool is None:
44-
raise ValueError("Pool is not started")
45+
raise RuntimeError(NOT_STARTED)
4546
return self.pool.schedule(fn, args=args, kwargs=kwargs, timeout=timeout)
4647

4748
def submit(self, *args, **kwargs) -> Future:
4849
if self.pool is None:
49-
raise ValueError("Pool is not started")
50+
raise RuntimeError(NOT_STARTED)
5051
return self.pool.submit(*args, **kwargs)
5152

5253
def map(self, *args, **kwargs):
5354
if self.pool is None:
54-
raise ValueError("Pool is not started")
55+
raise RuntimeError(NOT_STARTED)
5556
return self.pool.map(*args, **kwargs)
5657

5758

58-
POOL = WorkerPool()
59+
POOL = SingletonWorkerPool()
5960

6061

6162
def start_pool(max_workers: int = None):

giskard/utils/worker_pool.py

Lines changed: 80 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from multiprocessing import Process, Queue, SimpleQueue, cpu_count, get_context
1313
from multiprocessing.context import SpawnContext, SpawnProcess
1414
from multiprocessing.managers import SyncManager
15+
from queue import Empty
1516
from threading import Thread
1617
from time import sleep
1718
from uuid import uuid4
@@ -72,13 +73,12 @@ class GiskardTask:
7273
@dataclass(frozen=True)
7374
class GiskardResult:
7475
id: str
76+
logs: str
7577
result: Any = None
7678
exception: Any = None
77-
logs: str = None
7879

7980

8081
def _process_worker(tasks_queue: SimpleQueue, tasks_results: SimpleQueue, running_process: Dict[str, str]):
81-
# TODO(Bazire): For now, worker stops because the queues will be closed. Not the cleanest...
8282
pid = os.getpid()
8383

8484
while True:
@@ -93,7 +93,7 @@ def _process_worker(tasks_queue: SimpleQueue, tasks_results: SimpleQueue, runnin
9393
handler = logging.StreamHandler(f)
9494
logging.getLogger().addHandler(handler)
9595
try:
96-
LOGGER.debug("Doing task %", task.id)
96+
LOGGER.debug("Doing task %s", task.id)
9797
running_process[task.id] = pid
9898
result = task.callable(*task.args, **task.kwargs)
9999
to_return = GiskardResult(id=task.id, result=result, logs=f.getvalue())
@@ -119,7 +119,8 @@ class PoolState(Enum):
119119

120120

121121
class WorkerPoolExecutor(Executor):
122-
def __init__(self, nb_workers: Optional[int] = None):
122+
def __init__(self, nb_workers: Optional[int] = None, name: Optional[str] = None):
123+
self._prefix = f"{name}_" if name is not None else "giskard_pool_"
123124
if nb_workers is None:
124125
nb_workers = cpu_count()
125126
if nb_workers <= 0:
@@ -135,33 +136,42 @@ def __init__(self, nb_workers: Optional[int] = None):
135136
# Mapping of the running tasks and worker pids
136137
self._running_process: Dict[str, str] = self._manager.dict()
137138
# Tasks scheduled with a timeout, tracked so the killer thread can stop them
138-
self._with_timeout_tasks: List[Tuple[str, float]] = []
139-
# Queue with tasks to run
140-
self._pending_tasks_queue: SimpleQueue[GiskardTask] = self._mp_context.SimpleQueue()
139+
self._with_timeout_tasks: List[TimeoutData] = []
141140
# Queue with tasks to run
141+
self._pending_tasks_queue: Queue[GiskardTask] = self._mp_context.Queue()
142+
# Queue with tasks to be consumed asap
142143
# As in ProcessPool, add one more to avoid idling process
143-
self._running_tasks_queue: Queue[GiskardTask] = self._mp_context.Queue(maxsize=self._nb_workers + 1)
144+
self._running_tasks_queue: Queue[Optional[GiskardTask]] = self._mp_context.Queue(maxsize=self._nb_workers + 1)
144145
# Queue with results to notify
145-
self._tasks_results: SimpleQueue[GiskardResult] = self._mp_context.SimpleQueue()
146+
self._tasks_results: Queue[GiskardResult] = self._mp_context.Queue()
146147
# Mapping task_id with future
147148
self._futures_mapping: Dict[str, Future] = dict()
149+
LOGGER.debug("Starting threads for the WorkerPoolExecutor")
148150

149151
self._threads = [
150-
Thread(name="giskard_pool_" + target.__name__, target=target, daemon=True, args=[self], kwargs=None)
152+
Thread(name=f"{self._prefix}{target.__name__}", target=target, daemon=True, args=[self], kwargs=None)
151153
for target in [_killer_thread, _feeder_thread, _results_thread]
152154
]
153155
for t in self._threads:
154156
t.start()
157+
LOGGER.debug("Threads started, spawning workers...")
158+
LOGGER.debug("Starting the pool with %s", {self._nb_workers})
155159

156160
# Startup workers
157161
for _ in range(self._nb_workers):
158162
self._spawn_worker()
163+
LOGGER.info("WorkerPoolExecutor is started")
164+
165+
def health_check(self):
166+
if any([not p.is_alive() for p in self._processes.values()]):
167+
LOGGER.warning("At least one process died for an unknown reason, marking pool as broken")
168+
self._state = PoolState.BROKEN
159169

160170
def _spawn_worker(self):
161171
# Daemon means processes are linked to the main one, and will be stopped if the current process is stopped
162172
p = self._mp_context.Process(
163173
target=_process_worker,
164-
name="giskard_worker_process",
174+
name=f"{self._prefix}_worker_process",
165175
args=(self._running_tasks_queue, self._tasks_results, self._running_process),
166176
daemon=True,
167177
)
@@ -190,7 +200,7 @@ def schedule(
190200
self._futures_mapping[task.id] = res
191201
self._pending_tasks_queue.put(task)
192202
if timeout is not None:
193-
self._with_timeout_tasks.append((task.id, time.monotonic() + timeout))
203+
self._with_timeout_tasks.append(TimeoutData(task.id, time.monotonic() + timeout))
194204
return res
195205

196206
def shutdown(self, wait=True, timeout: float = 5):
@@ -207,8 +217,14 @@ def shutdown(self, wait=True, timeout: float = 5):
207217
while not self._running_tasks_queue.empty():
208218
try:
209219
self._running_tasks_queue.get_nowait()
210-
except BaseException:
211-
pass
220+
except ValueError as e:
221+
# This happens if the queue is closed
222+
LOGGER.warning("Running task queue is already closed")
223+
LOGGER.exception(e)
224+
except Empty as e:
225+
# May happen if a process consumes an element
226+
LOGGER.warning("Queue was empty, skipping")
227+
LOGGER.exception(e)
212228
# Try to nicely stop the worker, by adding None into the running tasks
213229
for _ in range(self._nb_workers):
214230
self._running_tasks_queue.put(None, timeout=1)
@@ -235,92 +251,97 @@ def shutdown(self, wait=True, timeout: float = 5):
235251
return exit_codes
236252

237253

254+
def _safe_get(queue: Queue, executor: WorkerPoolExecutor, timeout: float = 1) -> Tuple[Any, bool]:
255+
try:
256+
result = queue.get(timeout=1)
257+
except Empty:
258+
result = None
259+
except ValueError as e:
260+
# If queue is closed
261+
if executor._state not in FINAL_STATES:
262+
LOGGER.error("Queue is closed, and executor not in final state")
263+
executor._state = PoolState.BROKEN
264+
raise e
265+
return None, True
266+
if executor._state in FINAL_STATES:
267+
return None, True
268+
return result, False
269+
270+
238271
def _results_thread(
239272
executor: WorkerPoolExecutor,
240273
):
241274
# Goal of this thread is to feed the running tasks from pending one as soon as possible
242-
while True:
243-
while executor._state not in FINAL_STATES and (executor._tasks_results.empty()):
244-
# TODO(Bazire): find a way to improve this ?
245-
# Cannot use select, since we want to be windows compatible
246-
sleep(0.01)
247-
if executor._state in FINAL_STATES:
275+
# while True:
276+
while executor._state not in FINAL_STATES:
277+
result, should_stop = _safe_get(executor._tasks_results, executor)
278+
if should_stop:
248279
return
249-
result = executor._tasks_results.get()
250-
future = executor._futures_mapping.get(result.id)
251-
if future.cancelled():
252-
try:
253-
del executor._futures_mapping[result.id]
254-
except BaseException:
255-
pass
280+
if result is None:
281+
continue
282+
283+
future = executor._futures_mapping.pop(result.id, None)
284+
if future is None or future.cancelled():
256285
continue
257286
future.logs = result.logs
258287
if result.exception is None:
259288
future.set_result(result.result)
260289
else:
261-
# TODO(Bazire): improve to get Traceback
262290
future.set_exception(RuntimeError(result.exception))
263-
try:
264-
del executor._futures_mapping[result.id]
265-
except BaseException:
266-
pass
267291

268292

269293
def _feeder_thread(
270294
executor: WorkerPoolExecutor,
271295
):
272296
# Goal of this thread is to feed the running-tasks queue from the pending ones as soon as possible
273-
while True:
274-
while executor._state not in FINAL_STATES and (
275-
executor._running_tasks_queue.full() or executor._pending_tasks_queue.empty()
276-
):
277-
# TODO(Bazire): find a way to improve this ?
278-
# Cannot use select, since we want to be windows compatible
279-
sleep(0.01)
280-
if executor._state in FINAL_STATES:
297+
while executor._state not in FINAL_STATES:
298+
task, should_stop = _safe_get(executor._pending_tasks_queue, executor)
299+
if should_stop:
281300
return
282-
task = executor._pending_tasks_queue.get()
301+
if task is None:
302+
continue
303+
283304
future = executor._futures_mapping.get(task.id)
284-
if future is not None and future.set_running_or_notify_cancel():
305+
if future is None:
306+
continue
307+
if future.set_running_or_notify_cancel():
285308
executor._running_tasks_queue.put(task)
286-
elif future is not None:
309+
else:
287310
# Future has been cancelled already, nothing to do
288-
del executor._futures_mapping[task.id]
311+
executor._futures_mapping.pop(task.id, False)
289312

290313

291314
def _killer_thread(
292315
executor: WorkerPoolExecutor,
293316
):
294-
while True:
317+
while executor._state not in FINAL_STATES:
295318
while len(executor._with_timeout_tasks) == 0 and executor._state not in FINAL_STATES:
296319
# No need to be too active
297320
sleep(1)
321+
executor.health_check()
298322
if executor._state in FINAL_STATES:
299323
return
300324

301-
clean_up = []
325+
clean_up: List[TimeoutData] = []
302326
exception = None
303327
for timeout_data in executor._with_timeout_tasks:
304-
task_id, end_time = timeout_data
305-
if task_id not in executor._futures_mapping:
328+
if timeout_data.id not in executor._futures_mapping:
306329
# Task is already completed, do not wait for it
307330
clean_up.append(timeout_data)
308-
elif time.monotonic() > end_time:
331+
elif time.monotonic() > timeout_data.end_time:
309332
# Task has timed out, we should kill it
310333
try:
311-
pid = executor._running_process.get(task_id)
312-
if pid is None:
313-
# Task must have finished
314-
continue
315-
future = executor._futures_mapping.get(task_id)
316-
if not future.cancel():
334+
future = executor._futures_mapping.pop(timeout_data.id, None)
335+
if future is not None and not future.cancel():
317336
LOGGER.warning("Killing a timed out process")
318337
future.set_exception(TimeoutError("Task took too long"))
319-
p = executor._processes[pid]
320-
del executor._processes[pid]
321-
_stop_processes([p])
322-
executor._spawn_worker()
338+
pid = executor._running_process.pop(timeout_data.id, None)
339+
if pid is not None:
340+
p = executor._processes.pop(pid)
341+
_stop_processes([p])
342+
executor._spawn_worker()
323343
except BaseException as e:
344+
# This is probably an OSError, but we want to be extra safe
324345
LOGGER.warning("Unexpected error when killing a timed out process, pool is broken")
325346
LOGGER.exception(e)
326347
exception = e

0 commit comments

Comments
 (0)