Coverage for src/task/custom_task.py: 84%
129 statements
« prev ^ index » next coverage.py v7.9.0, created at 2025-10-13 12:26 +0000
« prev ^ index » next coverage.py v7.9.0, created at 2025-10-13 12:26 +0000
1from collections.abc import Callable
2from typing import Any, Literal, Self
4from celery import states
6from history.model_data import HistoryEventDict
7from history.models import HistoryEvent
8from history.utils import insert_history_event
11class TaskError(Exception):
12 task_index: int
13 sub_error: Self | Exception # Can be a TaskError
15 def __init__(self, task_index, sub_error):
16 self.task_index = task_index
17 self.sub_error = sub_error
19 def get_deep_error(self):
20 if isinstance(self.sub_error, TaskError):
21 return self.sub_error.get_deep_error()
22 return self.sub_error
24 def create_subtask_index_chain(self):
25 if isinstance(self.sub_error, TaskError):
26 return [str(self.task_index), *self.sub_error.create_subtask_index_chain()]
27 return [str(self.task_index)]
29 def __str__(self):
30 error = self.get_deep_error()
31 return (
32 f"({','.join(self.create_subtask_index_chain())}) {type(error).__name__}: {str(error)}"
33 )
class CustomTask:
    """
    A custom task is an object that can be executed inside a Celery Task using the run_as_task function

    It is a complex object because it can contain subtasks which will be executed in a specific order

    This is useful because of the flexibility it offers us. This covers both complex and simple situations since you
    can easily convert a function into a CustomTask:
    ```
    CustomTask(lambda: print("This is a function"))
    ```

    An important feature is the ability to recover from an error.
    By serializing the Error and the arguments used to instantiate the Task, you can easily recover from an error.
    This doesn't mean that the error will disappear by itself, but it means you can fix the error and then
    come back to where the process was without having to restart the whole thing

    CustomTask also allows us to have a much more precise progression tracking since we can track each subtask one
    by one

    If your task is using an external way of tracking progression, you can override the get_progression method
    of your task and implement your own way of keeping track of the progression.
    """

    name: str
    state: Literal["pending", "running", "finished", "crashed"] = "pending"
    current_index: int = 0  # Current subtask index

    # Built lazily by get_tasks; classes are replaced by their instance once they have been instantiated
    subtasks: list[Callable | Self | type[Self]] | None = None
    parent: Self | None = None  # Set by the parent task right before it runs this task as a subtask

    event_dict: HistoryEventDict | None = (
        None  # Used to insert the HistoryEvent at the end of the task
    )

    def __init__(self, func: Callable | None = None, *args):
        """
        Convert a regular function into a CustomTask

        In the case of inheritance, you can override the do method to define the task
        If you override the __init__ and the class is used as a subtask, you should pay attention to the subtask
        argument passing

        If the __init__ asks for no arguments, they will be passed to the do method
        If the __init__ asks for arguments, they will be passed to the __init__ and the do method will be
        called with no arguments

        If you want to force the do method to be called with the arguments, you can give a CustomTask object instead
        of the class inside _make_subtasks
        """
        # Compare with None so that a callable object which happens to be falsy is still used
        if func is not None:
            self.do = func

    def make_progress_data(self):
        """
        make_progress_data is used to create a dictionary containing the data that will be sent to the frontend at
        each progression update
        """
        return {}

    def get_progression(self, precise=True) -> float:  # -1 or [0..1]
        """
        Return the progression of the task: -1 if it crashed, 0 if pending, 1 if finished, otherwise the
        fraction of steps (do + subtasks) already done

        When precise is True, the progression of the step currently running is added if that step is a CustomTask
        """
        if self.state == "crashed":
            return -1

        if self.state == "pending":
            return 0

        if self.state == "finished":
            return 1

        subtasks = self.get_tasks()
        subtask_count = len(subtasks)  # Never 0 with the default get_tasks since do counts as a step

        done_subtask_count = self.current_index
        progression = done_subtask_count / subtask_count

        if not precise or done_subtask_count >= subtask_count:
            return progression

        subtask = subtasks[self.current_index]
        if isinstance(subtask, CustomTask):
            # A crashed subtask reports -1. Its update_progress is redirected to ours and runs before we are
            # marked as crashed ourselves, so without the clamp the progression would briefly go backwards
            progression += max(subtask.get_progression(), 0) / subtask_count

        return progression

    def do(self, *args, **kwargs) -> Any: ...

    @staticmethod
    def progress_callback(
        progress: int, progress_text: str | dict | None = None, state=states.STARTED
    ):
        """
        Called by runners to update the progress.
        Injected into the function passed into `run_as_task`
        """

    # Task state -> Celery state sent along with each progress update
    celery_progress_mapping = {
        "crashed": states.FAILURE,
        "pending": states.PENDING,
        "finished": states.SUCCESS,
        "running": states.STARTED,
    }

    def update_progress(self):
        """Send the current progression (as a rounded percentage), progress data and state to progress_callback"""
        percentage = round(self.get_progression() * 100)
        data = self.make_progress_data()

        self.progress_callback(percentage, data, state=self.celery_progress_mapping[self.state])

    def set_progress_callback(self, callback):
        """Replace progress_callback for this instance only"""
        self.progress_callback = callback

    # Task state -> status stored in the HistoryEvent
    history_progress_mapping: dict[str, HistoryEvent.EventStatusEnum] = {
        "crashed": HistoryEvent.EventStatusEnum.ERROR,
        "pending": HistoryEvent.EventStatusEnum.PENDING,
        "finished": HistoryEvent.EventStatusEnum.OK,
        "running": HistoryEvent.EventStatusEnum.PENDING,
    }

    def add_history_event(self, *args, **kwargs):
        """
        Insert event_dict as a history event with a status matching the current state

        Does nothing when event_dict is not set. The arguments are accepted but ignored
        """
        if not self.event_dict:
            return

        # Unknown states are reported as OK
        self.event_dict["status"] = self.history_progress_mapping.get(
            self.state, HistoryEvent.EventStatusEnum.OK
        )

        insert_history_event(self.event_dict)

    # do calls count as subtasks
    def get_tasks(self) -> list[Callable | Self | type[Self]]:
        """Return the steps of the task (do followed by the subtasks), building the list on the first call"""
        if self.subtasks is None:
            self.subtasks = [self.do, *self._make_subtasks()]
        return self.subtasks

    # Should not be called directly, but should be overridden in subclasses
    def _make_subtasks(
        self,
    ) -> list[Callable | Self | type[Self]]:  # Returns a list containing the subtasks
        # Don't do recursive calls in here
        return []  # This can include random functions or really complex CustomTask, we don't really care

    def recover(self, error: TaskError | Exception, *args):
        """
        Run the task again, starting from the step that raised the given TaskError

        Raises a ValueError if the error is not a TaskError
        """
        if not isinstance(error, TaskError):
            raise ValueError("The error should be a TaskError")

        return self(*args, recover_index=error.task_index, recover_error=error)

    def on_error(self, error: Exception) -> bool:
        """
        Called when an error occurs in the task

        returns True if the error is skippable, False otherwise
        """
        return False

    def __call__(
        self, *args, recover_index: int = 0, recover_error: TaskError | None = None, **kwargs
    ):
        """
        Run do and then every subtask in order, sending the return value of each step to the next one

        args are given to the first step that is executed, kwargs are only offered to that first step as well

        recover_index is used to recover from an error: every step before this index is skipped
        recover_error is the TaskError we are recovering from. If the step at recover_index is a CustomTask and
        the recorded sub error is a TaskError, the recovery continues inside that subtask. Every other step is
        run normally

        Returns the value returned by the last step (None if it returned nothing)

        Raises a TaskError wrapping the original error together with the index of the step that failed
        """

        self.state = "running"
        self.current_index = 0

        # This will be used when calling subtasks so that we send the return value of each subtask to the next one
        next_args = args
        try:
            for task in self.get_tasks():
                # We're skipping the subtasks we've already done if we're in recovery mode
                if self.current_index < recover_index:
                    self.current_index += 1
                    continue

                # The goal when calling subtasks is to always have an iterable that we can split to get the arguments
                # Since we work only using positional arguments in tasks, every value is converted into a tuple
                # that we can then split
                #
                # Example:
                # 5 -tuple-> (5,) -split-> 5
                # () -tuple-> () -split-> *nothing*
                # (5,2) -tuple-> (5,2) -split-> 5, 2
                if not isinstance(next_args, tuple):
                    next_args = (next_args,)

                # Creating the dynamic CustomTasks
                if isinstance(task, type):
                    # If the task is a class, we need to instantiate it
                    # If it has an overridden constructor asking for arguments, it receives the next arguments
                    # Otherwise the class is called with no arguments and the arguments go to the task call
                    # Note: co_argcount only counts named positional parameters (self included), so a
                    # constructor taking only *args is not considered as asking for arguments
                    has_overridden_constructor = task.__init__ is not CustomTask.__init__
                    asks_for_args = task.__init__.__code__.co_argcount > 1
                    if has_overridden_constructor and asks_for_args:
                        task = task(*next_args, **kwargs)
                        next_args = ()
                        kwargs = {}
                    else:
                        task = task()
                    # We need to update the subtasks list with the instantiated task
                    # Or else we will have issues when updating the progression
                    if self.subtasks:
                        self.subtasks[self.current_index] = task

                if isinstance(task, CustomTask):
                    # Useful when calling the subtasks so we can update the parent progression instead of the child one
                    task.update_progress = self.update_progress
                    task.parent = self

                # We only recover deeper for the step that actually failed. The steps after it never ran,
                # so they must be executed normally and not be given an error that is not theirs
                resume_subtask = (
                    recover_error is not None
                    and self.current_index == recover_index
                    and isinstance(task, CustomTask)
                    and isinstance(recover_error.sub_error, TaskError)
                )
                if resume_subtask:
                    next_args = task.recover(recover_error.sub_error, *next_args)
                else:
                    try:
                        next_args = task(*next_args, **kwargs)
                    except Exception as error:
                        if not self.on_error(error):
                            raise
                        # The error is skippable, the next step is called with no arguments
                        next_args = ()

                kwargs = {}

                # If the return value is not a tuple, we're either converting it into an empty tuple or keeping its value
                next_args = next_args if next_args is not None else ()
                self.current_index += 1
                self.update_progress()
        except Exception as error:
            self.state = "crashed"
            self.update_progress()
            raise TaskError(self.current_index, error) from error

        self.state = "finished"
        self.update_progress()
        # Converting empty tuple into None value instead
        if next_args == ():
            next_args = None

        return next_args