Coverage for src/task/custom_task.py: 84%

129 statements  

« prev     ^ index     » next       coverage.py v7.9.0, created at 2025-10-13 12:26 +0000

1from collections.abc import Callable 

2from typing import Any, Literal, Self 

3 

4from celery import states 

5 

6from history.model_data import HistoryEventDict 

7from history.models import HistoryEvent 

8from history.utils import insert_history_event 

9 

10 

11class TaskError(Exception): 

12 task_index: int 

13 sub_error: Self | Exception # Can be a TaskError 

14 

15 def __init__(self, task_index, sub_error): 

16 self.task_index = task_index 

17 self.sub_error = sub_error 

18 

19 def get_deep_error(self): 

20 if isinstance(self.sub_error, TaskError): 

21 return self.sub_error.get_deep_error() 

22 return self.sub_error 

23 

24 def create_subtask_index_chain(self): 

25 if isinstance(self.sub_error, TaskError): 

26 return [str(self.task_index), *self.sub_error.create_subtask_index_chain()] 

27 return [str(self.task_index)] 

28 

29 def __str__(self): 

30 error = self.get_deep_error() 

31 return ( 

32 f"({','.join(self.create_subtask_index_chain())}) {type(error).__name__}: {str(error)}" 

33 ) 

34 

35 

class CustomTask:
    """
    A custom task is an object that can be executed inside a Celery Task using the run_as_task function

    It is a complex object because it can contain subtasks which will be executed in a specific order

    This is useful because of the flexibility it offers us. This covers both complex and simple situations since you
    can easily convert a function into a CustomTask:
    ```
    CustomTask(lambda: print("This is a function"))
    ```

    An important feature is the ability to recover from an error.
    By serializing the Error and the arguments used to instantiate the Task, you can easily recover from an error.
    This doesn't mean that the error will disappear by itself, but it means you can fix the error and then
    come back to where the process was without having to restart the whole thing

    CustomTask also allows us to have a much more precise progression tracking since we can track each subtask one
    by one

    If your task is using an external way of tracking progression, you can override the get_progression method of
    your task and implement your own way of keeping track of the progression.
    """

    name: str
    state: Literal["pending", "running", "finished", "crashed"] = "pending"
    current_index: int = 0  # Current subtask index

    subtasks: list[Callable | Self] | None = None
    parent: Self | None = None

    # Used to insert the HistoryEvent at the end of the task
    event_dict: HistoryEventDict | None = None

    def __init__(self, func: Callable | None = None, *args):
        """
        Convert a regular function into a CustomTask

        In the case of inheritance, you can override the do method to define the task
        If you override the __init__ and the class is used as a subtask, you should pay attention to the subtask
        argument passing

        If the __init__ asks for no arguments, they will be passed to the do method
        If the __init__ asks for arguments, they will be passed to the __init__ and the do method will be
        called with no arguments

        If you want to force the do method to be called with the arguments, you can give a CustomTask object instead
        of the class inside _make_subtasks
        """
        # The extra positional arguments are accepted but not used here
        if func:
            self.do = func

    def make_progress_data(self):
        """
        make_progress_data is used to create a dictionary containing the data that will be sent to the frontend at
        each progression update
        """
        return {}

    def get_progression(self, precise=True) -> float:  # -1 or [0..1]
        """
        Returns the progression of the task: -1 when crashed, 0 when pending, 1 when finished and otherwise the
        fraction of subtasks already done

        When precise is True, the progression of the subtask currently running is added (scaled to the weight of
        one subtask) if that subtask is a CustomTask
        """
        if self.state == "crashed":
            return -1

        if self.state == "pending":
            return 0

        if self.state == "finished":
            return 1

        subtasks = self.get_tasks()
        subtask_count = len(subtasks)

        done_subtask_count = self.current_index
        progression = done_subtask_count / subtask_count

        if not precise or done_subtask_count >= subtask_count:
            return progression

        # Plain callables (and classes not instantiated yet) can't report a progression, they count as 0
        subtask = subtasks[self.current_index]
        subtask_progression = subtask.get_progression() if isinstance(subtask, CustomTask) else 0

        progression += subtask_progression / subtask_count

        return progression

    def do(self, *args, **kwargs) -> Any: ...

    @staticmethod
    def progress_callback(
        progress: int, progress_text: str | dict | None = None, state=states.STARTED
    ):
        """
        Called by runners to update the progress.
        Injected into the function passed into `run_as_task`
        """

    # Celery state reported to the progress callback for each task state
    celery_progress_mapping = {
        "crashed": states.FAILURE,
        "pending": states.PENDING,
        "finished": states.SUCCESS,
        "running": states.STARTED,
    }

    def update_progress(self):
        """
        Sends the current progression (as a rounded percentage), the progress data and the Celery state matching
        the task state to the progress callback
        """
        progression = self.get_progression()
        progression = round(progression * 100)

        data = self.make_progress_data()

        self.progress_callback(progression, data, state=self.celery_progress_mapping[self.state])

    def set_progress_callback(self, callback):
        """Replaces the progress callback of this instance"""
        self.progress_callback = callback

    # HistoryEvent status stored for each task state
    history_progress_mapping: dict[str, HistoryEvent.EventStatusEnum] = {
        "crashed": HistoryEvent.EventStatusEnum.ERROR,
        "pending": HistoryEvent.EventStatusEnum.PENDING,
        "finished": HistoryEvent.EventStatusEnum.OK,
        "running": HistoryEvent.EventStatusEnum.PENDING,
    }

    def add_history_event(self, *args, **kwargs):
        """
        Inserts the HistoryEvent described by event_dict with a status matching the task state

        Does nothing when the task has no event_dict. The arguments are accepted but not used
        """
        if not self.event_dict:
            return

        # A state missing from the mapping is stored as OK
        self.event_dict["status"] = self.history_progress_mapping.get(
            self.state, HistoryEvent.EventStatusEnum.OK
        )

        insert_history_event(self.event_dict)

    # do calls count as subtasks
    def get_tasks(self) -> list[Callable | Self | type[Self]]:
        """Returns the tasks to run: the do method followed by the subtasks. The list is built once and cached"""
        if self.subtasks is None:
            self.subtasks = [self.do, *self._make_subtasks()]
        return self.subtasks

    # Should not be called directly, but should be overridden in subclasses
    def _make_subtasks(
        self,
    ) -> list[Callable | Self | type[Self]]:  # Returns a list containing the subtasks
        # Don't do recursive calls in here
        return []  # This can include random functions or really complex CustomTask, we don't really care

    def recover(self, error: TaskError | Exception, *args):
        """
        Runs the task again, starting from the subtask that raised the given TaskError

        Raises a ValueError if the error is not a TaskError
        """
        if not isinstance(error, TaskError):
            raise ValueError("The error should be a TaskError")

        return self(*args, recover_index=error.task_index, recover_error=error)

    def on_error(self, error: Exception) -> bool:
        """
        Called when an error occurs in the task

        returns True if the error is skippable, False otherwise
        """
        return False

    def __call__(
        self, *args, recover_index: int = 0, recover_error: TaskError | None = None, **kwargs
    ):
        """
        Runs the tasks returned by get_tasks in order, sending the return value of each one to the next one as
        positional arguments, and returns the value of the last one (None if it returned nothing)

        The keyword arguments are only given to the first task that is executed

        recover_index is used to recover from an error: the subtasks before this index are skipped
        recover_error is the TaskError we are recovering from, its sub_error is used to recover deeper inside the
        subtask that crashed

        Any error that is not skipped by on_error is raised wrapped in a TaskError holding the index of the
        failing subtask
        """

        self.state = "running"
        self.current_index = 0

        # This will be used when calling subtasks so that we send the return value of each subtask to the next one
        next_args = args
        try:
            for task in self.get_tasks():
                # We're skipping the subtasks we've already done if we're in recovery mode
                if self.current_index < recover_index:
                    self.current_index += 1
                    continue

                # The goal when calling subtasks is to always have an iterable that we can split to get the arguments
                # Since we work only using positional arguments in tasks, handling tasks with no arguments
                # while also working with possible multiple arguments can be tricky
                # This does the work because a split empty tuple will be *nothing*
                # So we're converting every value we are given into tuples that we can then split
                #
                # Example:
                # 5     -tuple-> (5,)  -split-> 5
                # ()    -tuple-> ()    -split-> *nothing*
                # (5,2) -tuple-> (5,2) -split-> 5, 2
                next_args = (next_args,) if not isinstance(next_args, tuple) else next_args

                # Creating the dynamic CustomTasks
                if isinstance(task, type):
                    # If the task is a class, we need to instantiate it
                    # We're checking if the class has an overridden constructor
                    # If it does, we're calling it with the next arguments (if it asks for some)
                    # If it doesn't, we're calling the class with no arguments
                    has_overridden_constructor = task.__init__ is not CustomTask.__init__
                    asks_for_args = task.__init__.__code__.co_argcount > 1
                    if has_overridden_constructor and asks_for_args:
                        task = task(*next_args, **kwargs)
                        next_args = ()
                        kwargs = {}
                    else:
                        task = task()
                    # We need to update the subtasks list with the instantiated task
                    # Or else we will have issues when updating the progression
                    if self.subtasks:
                        self.subtasks[self.current_index] = task

                if isinstance(task, CustomTask):
                    # Useful when calling the subtasks so we can update the parent progression instead of the
                    # child one
                    task.update_progress = self.update_progress

                    task.parent = self

                # We can only recover deeper inside the subtask that crashed, which is the one at recover_index,
                # and only if it raised a TaskError
                # The subtasks coming after it never ran, so they must be executed normally: recovering them with
                # the same error would make them skip their first subtasks or reject a plain Exception
                recover_deeper = (
                    recover_error is not None
                    and self.current_index == recover_index
                    and isinstance(task, CustomTask)
                    and isinstance(recover_error.sub_error, TaskError)
                )

                if recover_deeper:
                    next_args = task.recover(recover_error.sub_error, *next_args)
                else:
                    try:
                        next_args = task(*next_args, **kwargs)
                    except Exception as error:
                        skippable = self.on_error(error)
                        if not skippable:
                            raise
                        next_args = ()

                kwargs = {}

                # If the return value is not a tuple, we're either converting it into an empty tuple or keeping
                # its value
                next_args = next_args if next_args is not None else ()
                self.current_index += 1
                self.update_progress()
        except Exception as error:
            self.state = "crashed"
            self.update_progress()
            # current_index is the index of the subtask that failed, it is what recover uses to resume
            raise TaskError(self.current_index, error) from error

        self.state = "finished"
        self.update_progress()
        # Converting empty tuple into None value instead
        if next_args == ():
            next_args = None

        return next_args

282 

283 return next_args