ППару дней назад коллега попросил сделать логирующий сам себя итератор поверх enumerate. Я попробовал наследоваться напрямую и потерпел неудачу. Я абсолютно забыл, как работает магический метод __new__. Поскольку я был занят, я пообещал себе разобраться с этой проблемой позже. А ларчик открывался очень просто: 18 строк кода — и у меня появилась нужная функциональность.
Изначальная задача
Сначала стоит объяснить, зачем нам вообще понадобился подобный итератор и почему нам не хватило обычного enumerate. Всё дело в том, что у нас в проекте очень много задач, построенных по такому шаблону:
- Получить пачку объектов из базы.
- Написать в лог, сколько объектов получили.
- С каждым полученным объектом сделать что‑либо, отчитываясь о ходе работы после каждого X‑го объекта.
- Написать в лог о завершении задачи.
В Python это выглядит как‑то так:
iterable = get_bunch()
total = len(iterable)
print("total: {}".format(total))
for i, item in enumerate(iterable, start=1):
try:
func(item)
except Exception as e:
print("catch exception: {}".format(e))
if not i % 100:
print("done {} of {}".format(i, total))
print("Done!")
Вся разница между этими задачами — в теле функции func и в сообщениях в лог. Вся эта структура копировалась раз за разом. Так что мы решили избавиться от этого, сделав свой итератор, который бы прятал всё лишнее.
Реализация
Ок, давайте сделаем класс, который будет наследником enumerate. Как я говорил выше, нам придётся переопределить метод __new__, так как enumerate делает это. Согласно документации, если __new__() возвращает экземпляр класса, тогда метод __init__() нового инстанса будет вызываться с теми же аргументами. Так что у меня получилась такая реализация:
class LogEnumerate(enumerate):
def __new__(cls, iterable, start=1, *args, **kwargs):
return super(LogEnumerate, cls).__new__(cls, iterable, start)
def __init__(self, iterable, start=1, step=10,
start_message='', progress_message='', stop_message=''):
self.progress_message = progress_message
self.stop_message = stop_message
self.step = step
self.total = len(iterable)
print(start_message.format(start_message))
def __next__(self):
try:
i, item = super().__next__()
if not i % self.step:
print(self.progress_message.format(i, self.total))
return item
except StopIteration:
print(self.stop_message)
raise