Coverage for src/task/custom_task.py: 84%
128 statements
« prev ^ index » next coverage.py v7.9.0, created at 2025-11-03 17:19 +0000
« prev ^ index » next coverage.py v7.9.0, created at 2025-11-03 17:19 +0000
1from collections.abc import Callable
2from typing import Any, Literal, Self
4from celery import states
6from history.model_data import HistoryEventDict, HistoryEventStatus
7from history.utils import insert_history_event
10class TaskError(Exception):
11 task_index: int
12 sub_error: Self | Exception # Can be a TaskError
14 def __init__(self, task_index, sub_error):
15 self.task_index = task_index
16 self.sub_error = sub_error
18 def get_deep_error(self):
19 if isinstance(self.sub_error, TaskError):
20 return self.sub_error.get_deep_error()
21 return self.sub_error
23 def create_subtask_index_chain(self):
24 if isinstance(self.sub_error, TaskError):
25 return [str(self.task_index), *self.sub_error.create_subtask_index_chain()]
26 return [str(self.task_index)]
28 def __str__(self):
29 error = self.get_deep_error()
30 return (
31 f"({','.join(self.create_subtask_index_chain())}) {type(error).__name__}: {str(error)}"
32 )
class CustomTask:
    """
    A custom task is an object that can be executed inside a Celery task using the run_as_task function.

    It is a complex object because it can contain subtasks, which are executed in a specific order.

    This is useful because of the flexibility it offers. It covers both complex and simple situations, since you
    can easily convert a function into a CustomTask:
    ```
    CustomTask(lambda: print("This is a function"))
    ```

    An important feature is the ability to recover from an error.
    By serializing the error and the arguments used to instantiate the task, you can easily recover from an error.
    This doesn't mean that the error will disappear by itself, but it means you can fix the error and then
    come back to where the process was without having to restart the whole thing.

    CustomTask also allows much more precise progression tracking, since each subtask can be tracked one
    by one.

    If your task uses an external way of tracking progression, you can override the get_progression method
    of your task and implement your own way of keeping track of the progression.
    """

    name: str
    state: Literal["pending", "running", "finished", "crashed"] = "pending"
    current_index: int = 0  # Index of the subtask currently being executed

    # Built lazily by get_tasks(); subtasks given as classes are replaced by their instance once created
    subtasks: list[Callable | Self | type[Self]] | None = None
    parent: Self | None = None

    event_dict: HistoryEventDict | None = None  # Used to insert the HistoryEvent at the end of the task

    def __init__(self, func: Callable | None = None, *args):
        """
        Convert a regular function into a CustomTask.

        func, when given, replaces the do method of this instance.
        Extra positional arguments are accepted but ignored.

        In the case of inheritance, you can override the do method to define the task.
        If you override __init__ and the class is used as a subtask, pay attention to how the subtask
        arguments are passed:

        - If __init__ asks for no arguments, the arguments are passed to the do method.
        - If __init__ asks for arguments, they are passed to __init__ and the do method is called with
          no arguments.

        If you want to force the do method to be called with the arguments, you can give a CustomTask object
        instead of the class inside _make_subtasks.
        """
        # `is not None` rather than truthiness: a callable object may legitimately be falsy
        if func is not None:
            self.do = func

    def make_progress_data(self):
        """
        make_progress_data is used to create a dictionary containing the data that will be sent to the frontend at
        each progression update
        """
        return {}

    def get_progression(self, precise=True) -> float:  # -1 or [0..1]
        """
        Return -1 if the task crashed, otherwise the fraction of the task that is done.

        While running, the progression is the share of completed subtasks. When precise is True and the
        current subtask is a CustomTask, that subtask's own progression is added proportionally.
        """
        if self.state == "crashed":
            return -1
        if self.state == "pending":
            return 0
        if self.state == "finished":
            return 1

        subtasks = self.get_tasks()
        subtask_count = len(subtasks)  # Always >= 1: get_tasks() includes self.do
        progression = self.current_index / subtask_count

        if not precise or self.current_index >= subtask_count:
            return progression

        subtask = subtasks[self.current_index]
        if isinstance(subtask, CustomTask):
            # A crashed subtask reports -1: don't let it pull the parent below the share of subtasks
            # that already completed (or below 0)
            progression += max(subtask.get_progression(), 0) / subtask_count
        return progression

    # Body of the task; override it in subclasses or pass `func` to __init__
    def do(self, *args, **kwargs) -> Any: ...

    @staticmethod
    def progress_callback(
        progress: int, progress_text: str | dict | None = None, state=states.STARTED
    ):
        """
        Called by runners to update the progress.
        Injected into the function passed into `run_as_task`
        """

    celery_progress_mapping = {
        "crashed": states.FAILURE,
        "pending": states.PENDING,
        "finished": states.SUCCESS,
        "running": states.STARTED,
    }

    def update_progress(self):
        """Report the current progression (as a rounded percentage) and state through progress_callback."""
        progression = round(self.get_progression() * 100)
        data = self.make_progress_data()
        self.progress_callback(progression, data, state=self.celery_progress_mapping[self.state])

    def set_progress_callback(self, callback):
        """Replace progress_callback on this instance; it is called as callback(progress, data, state=...)."""
        self.progress_callback = callback

    history_progress_mapping: dict[str, HistoryEventStatus] = {
        "crashed": HistoryEventStatus.ERROR,
        "pending": HistoryEventStatus.PENDING,
        "finished": HistoryEventStatus.OK,
        "running": HistoryEventStatus.PENDING,
    }

    def add_history_event(self, *args, **kwargs):
        """
        Insert self.event_dict as a history event, with its "status" set from the current state.

        Does nothing when event_dict is not set. Extra arguments are ignored.
        Note: mutates self.event_dict in place.
        """
        if not self.event_dict:
            return

        self.event_dict["status"] = self.history_progress_mapping.get(self.state, HistoryEventStatus.OK)
        insert_history_event(self.event_dict)

    # do calls count as subtasks
    def get_tasks(self) -> list[Callable | Self | type[Self]]:
        """Return (and cache) the ordered subtask list: self.do followed by _make_subtasks()."""
        if self.subtasks is None:
            self.subtasks = [self.do, *self._make_subtasks()]
        return self.subtasks

    # Should not be called directly, but should be overridden in subclasses
    def _make_subtasks(
        self,
    ) -> list[Callable | Self | type[Self]]:  # Returns a list containing the subtasks
        # Don't do recursive calls in here
        return []  # This can include random functions or really complex CustomTask, we don't really care

    def recover(self, error: TaskError | Exception, *args):
        """
        Re-run this task starting from the subtask that failed with `error`.

        args are the arguments to give to that failing subtask.
        Raises ValueError if error is not a TaskError.
        """
        if not isinstance(error, TaskError):
            raise ValueError("The error should be a TaskError")

        return self(*args, recover_index=error.task_index, recover_error=error)

    def on_error(self, error: Exception) -> bool:
        """
        Called when an error occurs in the task

        returns True if the error is skippable, False otherwise
        """
        return False

    def _instantiate_subtask(self, task_class: type, next_args: tuple, kwargs: dict):
        """
        Instantiate a subtask given as a class and store the instance in self.subtasks.

        If the class overrides __init__ with positional parameters, the pending arguments are consumed by the
        constructor (the returned next_args/kwargs are emptied). Otherwise the class is instantiated with no
        arguments and the pending arguments are left for its do method.

        Returns (task, next_args, kwargs).
        """
        constructor = task_class.__init__
        has_overridden_constructor = constructor is not CustomTask.__init__
        # Some __init__ implementations (C slot wrappers, ...) have no __code__: treat them as taking no args
        code = getattr(constructor, "__code__", None)
        asks_for_args = code is not None and code.co_argcount > 1

        if has_overridden_constructor and asks_for_args:
            task = task_class(*next_args, **kwargs)
            next_args, kwargs = (), {}
        else:
            task = task_class()

        # Replace the class by its instance, or the progression tracking would not see the subtask
        if self.subtasks:
            self.subtasks[self.current_index] = task
        return task, next_args, kwargs

    def __call__(
        self, *args, recover_index: int = 0, recover_error: TaskError | None = None, **kwargs
    ):
        """
        Run every subtask in order, feeding the return value of each subtask to the next one.

        args/kwargs are given to the first subtask that runs; kwargs are dropped after it.

        recover_index and recover_error are used to recover from an error (see recover). Subtasks before
        recover_index are skipped. If the subtask at recover_index is a CustomTask and recover_error wraps
        a TaskError, that subtask is itself recovered instead of restarted. Every later subtask runs normally.

        Returns the last subtask's return value (None when it returned nothing or an empty tuple).
        Raises TaskError(index of the failing subtask, error) for any error on_error does not skip.
        """
        self.state = "running"
        self.current_index = 0

        # Holds the return value of the previous subtask, used as the arguments of the next one
        next_args = args
        try:
            for task in self.get_tasks():
                # In recovery mode, skip the subtasks that were already done
                if self.current_index < recover_index:
                    self.current_index += 1
                    continue

                # Subtasks only take positional arguments, so every value is normalized into a tuple that
                # can be splatted: 5 -> (5,) -> 5 ; () -> () -> nothing ; (5, 2) -> (5, 2) -> 5, 2
                if not isinstance(next_args, tuple):
                    next_args = (next_args,)

                if isinstance(task, type):
                    task, next_args, kwargs = self._instantiate_subtask(task, next_args, kwargs)

                if isinstance(task, CustomTask):
                    # Let the subtask report progression through the parent, which includes it
                    task.update_progress = self.update_progress
                    task.parent = self

                # Only the subtask that actually failed can be recovered deeper, and only if its error
                # is itself a TaskError (i.e. it failed inside one of its own subtasks)
                recovering_subtask = (
                    recover_error is not None
                    and self.current_index == recover_index
                    and isinstance(task, CustomTask)
                    and isinstance(recover_error.sub_error, TaskError)
                )
                if recovering_subtask:
                    next_args = task.recover(recover_error.sub_error, *next_args)
                else:
                    try:
                        next_args = task(*next_args, **kwargs)
                    except Exception as error:
                        if not self.on_error(error):
                            raise
                        next_args = ()

                kwargs = {}
                # A subtask returning None passes no arguments to the next one
                if next_args is None:
                    next_args = ()
                self.current_index += 1
                self.update_progress()
        except Exception as error:
            self.state = "crashed"
            self.update_progress()
            raise TaskError(self.current_index, error) from error

        self.state = "finished"
        self.update_progress()
        # An empty tuple means "nothing returned"; the isinstance check avoids comparing arbitrary
        # (e.g. array-like) return values with ()
        if isinstance(next_args, tuple) and not next_args:
            return None
        return next_args