@@ -78,17 +78,25 @@ def websocket_log_actor(ml_worker: MLWorkerInfo, req: Dict, *args, **kwargs):
7878WEBSOCKET_ACTORS = dict ((action .name , websocket_log_actor ) for action in MLWorkerAction )
7979
8080
81- def wrapped_handle_result (action : MLWorkerAction , ml_worker : MLWorker , start : float , rep_id : Optional [str ]):
81+ def wrapped_handle_result (
82+ action : MLWorkerAction , ml_worker : MLWorker , start : float , rep_id : Optional [str ], ignore_timeout : bool
83+ ):
8284 def handle_result (future : Union [Future , Callable [..., websocket .WorkerReply ]]):
8385 log_pool_stats ()
8486
8587 info = None # Needs to be defined in case of cancellation
8688
8789 try :
8890 info : websocket .WorkerReply = future .result () if isinstance (future , Future ) else future ()
89- except CancelledError :
90- info : websocket .WorkerReply = websocket .Empty ()
91- logger .warning ("Task for %s has timed out and been cancelled" , action .name )
91+ except CancelledError as e :
92+ if ignore_timeout :
93+ info : websocket .WorkerReply = websocket .Empty ()
94+ logger .warning ("Task for %s has timed out and been cancelled" , action .name )
95+ else :
96+ info : websocket .WorkerReply = websocket .ErrorReply (
97+ error_str = str (e ), error_type = type (e ).__name__ , detail = traceback .format_exc ()
98+ )
99+ logger .warning (e )
92100 except Exception as e :
93101 info : websocket .WorkerReply = websocket .ErrorReply (
94102 error_str = str (e ), error_type = type (e ).__name__ , detail = traceback .format_exc ()
@@ -171,7 +179,7 @@ def parse_and_execute(
171179 )
172180
173181
174- def dispatch_action (callback , ml_worker , action , req , execute_in_pool , timeout = None ):
182+ def dispatch_action (callback , ml_worker , action , req , execute_in_pool , timeout = None , ignore_timeout = False ):
175183 # Parse the response ID
176184 rep_id = req ["id" ] if "id" in req .keys () else None
177185 # Parse the param
@@ -199,7 +207,7 @@ def dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout=N
199207 )
200208 start = time .process_time ()
201209
202- result_handler = wrapped_handle_result (action , ml_worker , start , rep_id )
210+ result_handler = wrapped_handle_result (action , ml_worker , start , rep_id , ignore_timeout = ignore_timeout )
203211 # If execution should be done in a pool
204212 if execute_in_pool :
205213 logger .debug ("Submitting for action %s '%s' into the pool" , action .name , callback .__name__ )
@@ -227,7 +235,9 @@ def dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout=N
227235 )
228236
229237
230- def websocket_actor (action : MLWorkerAction , execute_in_pool : bool = True , timeout : Optional [float ] = None ):
238+ def websocket_actor (
239+ action : MLWorkerAction , execute_in_pool : bool = True , timeout : Optional [float ] = None , ignore_timeout : bool = False
240+ ):
231241 """
232242 Register a function as an actor to an action from WebSocket connection
233243 """
@@ -238,7 +248,7 @@ def websocket_actor_callback(callback: callable):
238248 logger .debug (f'Registered "{ callback .__name__ } " for ML Worker "{ action .name } "' )
239249
240250 def wrapped_callback (ml_worker : MLWorker , req : dict , * args , ** kwargs ):
241- dispatch_action (callback , ml_worker , action , req , execute_in_pool , timeout )
251+ dispatch_action (callback , ml_worker , action , req , execute_in_pool , timeout , ignore_timeout )
242252
243253 WEBSOCKET_ACTORS [action .name ] = wrapped_callback
244254
@@ -664,7 +674,7 @@ def echo(params: websocket.EchoMsg, *args, **kwargs) -> websocket.EchoMsg:
664674 return params
665675
666676
667- @websocket_actor (MLWorkerAction .getPush , timeout = 30 )
677+ @websocket_actor (MLWorkerAction .getPush , timeout = 30 , ignore_timeout = True )
668678def get_push (
669679 client : Optional [GiskardClient ], params : websocket .GetPushParam , * args , ** kwargs
670680) -> websocket .GetPushResponse :
0 commit comments