Coverage for src/task/custom_task.py: 84%

128 statements  

« prev     ^ index     » next       coverage.py v7.9.0, created at 2025-11-03 17:19 +0000

1from collections.abc import Callable 

2from typing import Any, Literal, Self 

3 

4from celery import states 

5 

6from history.model_data import HistoryEventDict, HistoryEventStatus 

7from history.utils import insert_history_event 

8 

9 

10class TaskError(Exception): 

11 task_index: int 

12 sub_error: Self | Exception # Can be a TaskError 

13 

14 def __init__(self, task_index, sub_error): 

15 self.task_index = task_index 

16 self.sub_error = sub_error 

17 

18 def get_deep_error(self): 

19 if isinstance(self.sub_error, TaskError): 

20 return self.sub_error.get_deep_error() 

21 return self.sub_error 

22 

23 def create_subtask_index_chain(self): 

24 if isinstance(self.sub_error, TaskError): 

25 return [str(self.task_index), *self.sub_error.create_subtask_index_chain()] 

26 return [str(self.task_index)] 

27 

28 def __str__(self): 

29 error = self.get_deep_error() 

30 return ( 

31 f"({','.join(self.create_subtask_index_chain())}) {type(error).__name__}: {str(error)}" 

32 ) 

33 

34 

class CustomTask:
    """
    A custom task is an object that can be executed inside a Celery Task using the run_as_task function

    It is a complex object because it can contain subtasks which will be executed in a specific order

    This is useful because of the flexibility it offers us. This covers both complex and simple situations since you
    can easily convert a function into a CustomTask:
    ```
    CustomTask(lambda: print("This is a function"))
    ```

    An important feature is the ability to recover from an error.
    By serializing the Error and the arguments used to instantiate the Task, you can easily recover from an error.
    This doesn't mean that the error will disappear by itself, but it means you can fix the error and then
    come back to where the process was without having to restart the whole thing

    CustomTask also allows us to have a much more precise progression tracking since we can track each subtask one
    by one

    If your task is using an external way of tracking progression, you can override the get_progression method
    of your task and implement your own way of keeping track of the progression.
    """

    name: str
    state: Literal["pending", "running", "finished", "crashed"] = "pending"
    current_index: int = 0  # Current subtask index

    # Lazily built by get_tasks(); classes found in it are replaced by their instance once instantiated
    subtasks: list[Callable | Self] | None = None
    parent: Self | None = None  # Set by the parent task right before it runs this task as a subtask

    event_dict: HistoryEventDict | None = (
        None  # Used to insert the HistoryEvent at the end of the task
    )

    def __init__(self, func: Callable | None = None, *args):
        """
        Convert a regular function into a CustomTask

        In the case of inheritance, you can override the do method to define the task
        If you override the __init__ and the class is used as a subtask, you should pay attention to the subtask
        argument passing

        If the __init__ asks for no arguments, they will be passed to the do method
        If the __init__ asks for arguments, they will be passed to the __init__ and the do method will be
        called with no arguments

        If you want to force the do method to be called with the arguments, you can give a CustomTask object instead
        of the class inside _make_subtasks
        """
        # Extra positional arguments are accepted but not used here
        if func:
            self.do = func

    def make_progress_data(self):
        """
        make_progress_data is used to create a dictionary containing the data that will be sent to the frontend at
        each progression update
        """
        return {}

    def get_progression(self, precise=True) -> float:  # -1 or [0..1]
        """
        Return the progression of the task

        Returns -1 if the task crashed, 0 if it is pending and 1 if it is finished
        While running, it is the share of subtasks already done. When precise is True, the progression of the
        subtask currently running (if it is a CustomTask) is added on top of it
        """
        if self.state == "crashed":
            return -1

        if self.state == "pending":
            return 0

        if self.state == "finished":
            return 1

        subtasks = self.get_tasks()
        subtask_count = len(subtasks)  # Never 0: get_tasks always contains self.do

        done_subtask_count = self.current_index
        progression = done_subtask_count / subtask_count

        if not precise or done_subtask_count >= subtask_count:
            return progression

        subtask = subtasks[self.current_index]
        subtask_progression = subtask.get_progression() if isinstance(subtask, CustomTask) else 0

        # A crashed subtask reports -1: it must not pull the parent progression backwards
        progression += max(subtask_progression, 0) / subtask_count

        return progression

    def do(self, *args, **kwargs) -> Any: ...

    @staticmethod
    def progress_callback(
        progress: int, progress_text: str | dict | None = None, state=states.STARTED
    ):
        """
        Called by runners to update the progress.
        Injected into the function passed into `run_as_task`
        """

    # Task state -> Celery state sent along with each progress update
    celery_progress_mapping = {
        "crashed": states.FAILURE,
        "pending": states.PENDING,
        "finished": states.SUCCESS,
        "running": states.STARTED,
    }

    def update_progress(self):
        """Send the current progression (as a rounded percentage), progress data and state to the callback"""
        percentage = round(self.get_progression() * 100)
        data = self.make_progress_data()

        self.progress_callback(percentage, data, state=self.celery_progress_mapping[self.state])

    def set_progress_callback(self, callback):
        """Replace the progress callback of this instance"""
        self.progress_callback = callback

    # Task state -> status stored in the HistoryEvent
    history_progress_mapping: dict[str, HistoryEventStatus] = {
        "crashed": HistoryEventStatus.ERROR,
        "pending": HistoryEventStatus.PENDING,
        "finished": HistoryEventStatus.OK,
        "running": HistoryEventStatus.PENDING,
    }

    def add_history_event(self, *args, **kwargs):
        """
        Insert the HistoryEvent described by event_dict, with a status matching the current state

        Does nothing if event_dict is not set. The arguments are accepted but ignored
        """
        if not self.event_dict:
            return

        # Unknown states fall back to OK
        self.event_dict["status"] = self.history_progress_mapping.get(
            self.state, HistoryEventStatus.OK
        )

        insert_history_event(self.event_dict)

    # do calls count as subtasks
    def get_tasks(self) -> list[Callable | Self | type[Self]]:
        """Return the list of tasks to run (do first, then the subtasks), building it on the first call"""
        if self.subtasks is None:
            self.subtasks = [self.do, *self._make_subtasks()]
        return self.subtasks

    # Should not be called directly, but should be overriden in subclasses
    def _make_subtasks(
        self,
    ) -> list[Callable | Self | type[Self]]:  # Returns a list containing the subtasks
        # Don't do recursive calls in here
        return []  # This can include random functions or really complex CustomTask, we don't really care

    def recover(self, error: TaskError | Exception, *args):
        """
        Run the task again, starting from the subtask that raised the given error

        Raises a ValueError if the error is not a TaskError
        """
        if not isinstance(error, TaskError):
            raise ValueError("The error should be a TaskError")

        return self(*args, recover_index=error.task_index, recover_error=error)

    def on_error(self, error: Exception) -> bool:
        """
        Called when an error occurs in the task

        returns True if the error is skippable, False otherwise
        """
        return False

    def __call__(
        self, *args, recover_index: int = 0, recover_error: TaskError | None = None, **kwargs
    ):
        """
        Run every task returned by get_tasks in order, sending the return value of each one to the next one

        recover_index is used to recover from an error: the tasks before this index are skipped
        recover_error is the TaskError we are recovering from. It is only used for the first task actually
        executed (the one that crashed), so that it can recover deeper if it is a CustomTask

        The keyword arguments are only given to the first task executed

        Returns the return value of the last task (None instead of an empty tuple)
        Raises a TaskError wrapping any error that was not skipped by on_error
        """

        self.state = "running"
        self.current_index = 0

        # This will be used when calling subtasks so that we send the return value of each subtask to the next one
        next_args = args
        try:
            for task in self.get_tasks():
                # We're skipping the subtasks we've already done if we're in recovery mode
                if self.current_index < recover_index:
                    self.current_index += 1
                    continue

                # The goal when calling subtasks is to always have an iterable that we can split to get the arguments
                # Since we work only using positional arguments in tasks, we're converting every value we are given
                # into tuples that we can then split
                #
                # Example:
                # 5 -tuple-> (5,) -split-> 5
                # () -tuple-> () -split-> *nothing*
                # (5,2) -tuple-> (5,2) -split-> 5, 2
                if not isinstance(next_args, tuple):
                    next_args = (next_args,)

                # Creating the dynamic CustomTasks
                if isinstance(task, type):
                    # If the task is a class, we need to instantiate it
                    # If it has an overridden constructor asking for arguments, we're calling it with the next
                    # arguments. If it doesn't, we're calling the class with no arguments
                    # NOTE(review): co_argcount does not count *args, so a constructor taking only *args is
                    # presumably treated as asking for no arguments - confirm this is intended
                    has_overridden_constructor = task.__init__ is not CustomTask.__init__
                    asks_for_args = task.__init__.__code__.co_argcount > 1
                    if has_overridden_constructor and asks_for_args:
                        task = task(*next_args, **kwargs)
                        next_args = ()
                        kwargs = {}
                    else:
                        task = task()
                    # We need to update the subtasks list with the instantiated task
                    # Or else we will have issues when updating the progression
                    if self.subtasks:
                        self.subtasks[self.current_index] = task

                if isinstance(task, CustomTask):
                    # Useful when calling the subtasks so we can update the parent progression instead of the child one
                    task.update_progress = self.update_progress

                    task.parent = self

                # Only the first subtask we execute is the one that crashed: we're consuming the error here so
                # that the following subtasks are run normally instead of being "recovered" with an error that
                # is not theirs
                resumed_error, recover_error = recover_error, None

                # If we're in recovery mode and the subtask is a CustomTask we can recover deeper
                if resumed_error is not None and isinstance(task, CustomTask):
                    next_args = task.recover(resumed_error.sub_error, *next_args)
                else:
                    try:
                        next_args = task(*next_args, **kwargs)
                    except Exception as error:
                        if not self.on_error(error):
                            raise
                        next_args = ()

                kwargs = {}

                # If the return value is not a tuple, we're either converting it into an empty tuple or keeping its value
                next_args = next_args if next_args is not None else ()
                self.current_index += 1
                self.update_progress()
        except Exception as error:
            self.state = "crashed"
            self.update_progress()
            raise TaskError(self.current_index, error) from error

        self.state = "finished"
        self.update_progress()
        # Converting empty tuple into None value instead
        # (type check first so that we never trigger a custom __eq__ on the return value)
        if isinstance(next_args, tuple) and not next_args:
            next_args = None

        return next_args