autogen/flaml/automl/training_log.py

180 lines
4.9 KiB
Python

"""!
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
"""
import json
from typing import IO
from contextlib import contextmanager
import logging
# Module-level logger; messages from this file are emitted under the "flaml.automl" logger name.
logger = logging.getLogger("flaml.automl")
class TrainingLogRecord:
    """One entry of the training log, serializable as a single JSON object.

    The instance attributes mirror the constructor arguments one-to-one and
    are what gets written out by `dump()` / `__str__()`.

    Args:
        record_id: Identifier of this record within the log.
        iter_per_learner: Iteration counter stored with the record.
        logged_metric: Metric value stored with the record.
        trial_time: Time value recorded for the trial.
        wall_clock_time: Wall-clock time value recorded for the trial.
        validation_loss: Validation loss; coerced to `float` on construction.
        config: Configuration dict stored with the record.
        learner: Learner name stored with the record.
        sample_size: Sample size stored with the record.
    """

    def __init__(
        self,
        record_id: int,
        iter_per_learner: int,
        logged_metric: float,
        trial_time: float,
        wall_clock_time: float,
        validation_loss: float,
        config: dict,
        learner: str,
        sample_size: int,
    ):
        # NOTE: the assignment order below determines the key order of the
        # serialized JSON, so it is kept stable.
        # -- identification --
        self.record_id = record_id
        self.iter_per_learner = iter_per_learner
        # -- measurements --
        self.logged_metric = logged_metric
        self.trial_time = trial_time
        self.wall_clock_time = wall_clock_time
        # Coerce so that the value is always a plain Python float.
        self.validation_loss = float(validation_loss)
        # -- what was trained --
        self.config = config
        self.learner = learner
        self.sample_size = sample_size

    def dump(self, fp: IO[str]):
        """Write this record's attributes to `fp` as one JSON object.

        No trailing newline is written. Returns whatever `json.dump` returns.
        """
        return json.dump(self.__dict__, fp)

    @classmethod
    def load(cls, json_str: str):
        """Build an instance from a JSON object string.

        The decoded object's keys are passed as keyword arguments to the
        constructor of `cls`.
        """
        return cls(**json.loads(json_str))

    def __str__(self):
        """Return the record's attributes encoded as a JSON string."""
        return json.dumps(self.__dict__)
class TrainingLogCheckPoint(TrainingLogRecord):
    """Log entry that names the record id of the best record so far.

    The base-class initializer is not called, so an instance carries only the
    `curr_best_record_id` attribute and the inherited `dump()` serializes it
    as a JSON object with that single key.

    Args:
        curr_best_record_id: Id of the record considered best at the time the
            checkpoint is created.
    """

    def __init__(self, curr_best_record_id: int):
        # Intentionally the only attribute set on the instance.
        self.curr_best_record_id = curr_best_record_id
class TrainingLogWriter(object):
    """Writer for a line-oriented JSON training log.

    Each `append()` writes one `TrainingLogRecord` as a JSON object followed
    by a newline, and tracks which record has the lowest validation loss so
    far. `checkpoint()` writes a `TrainingLogCheckPoint` line naming that
    record.

    Args:
        output_filename: Path of the log file to write.
    """

    def __init__(self, output_filename: str):
        self.output_filename = output_filename
        self.file = None
        # Id of the best record written so far; None until one qualifies.
        self.current_best_loss_record_id = None
        self.current_best_loss = float("+inf")
        # Sample size of the current best record.
        self.current_sample_size = None
        # Id that will be assigned to the next appended record.
        self.current_record_id = 0

    def open(self):
        """Open the log file in write mode ("w"), replacing existing content.

        Any handle this writer already holds is closed first so it is not
        leaked.
        """
        self.close()
        self.file = open(self.output_filename, "w")

    def append_open(self):
        """Open the log file in append mode ("a").

        Any handle this writer already holds is closed first so it is not
        leaked.
        """
        self.close()
        self.file = open(self.output_filename, "a")

    def append(
        self,
        it_counter: int,
        train_loss: float,
        trial_time: float,
        wall_clock_time: float,
        validation_loss,
        config,
        learner,
        sample_size,
    ):
        """Write one record to the log and update the best-record tracking.

        Args:
            it_counter: Stored as the record's `iter_per_learner`.
            train_loss: Stored as the record's `logged_metric`.
            trial_time: Stored as the record's `trial_time`.
            wall_clock_time: Stored as the record's `wall_clock_time`.
            validation_loss: Validation loss of the trial; must not be None.
            config: Stored as the record's `config`.
            learner: Stored as the record's `learner`.
            sample_size: Stored as the record's `sample_size`.

        Raises:
            IOError: If the file has not been opened.
            ValueError: If `validation_loss` is None.
        """
        if self.file is None:
            raise IOError("Call open() to open the output file first.")
        if validation_loss is None:
            raise ValueError("TEST LOSS NONE ERROR!!!")
        record = TrainingLogRecord(
            self.current_record_id,
            it_counter,
            train_loss,
            trial_time,
            wall_clock_time,
            validation_loss,
            config,
            learner,
            sample_size,
        )
        # A record becomes the new best when its loss is strictly lower, or
        # when it ties the best loss using a larger sample size.
        is_strictly_better = validation_loss < self.current_best_loss
        is_tie_with_more_data = (
            validation_loss == self.current_best_loss
            and self.current_sample_size is not None
            and sample_size > self.current_sample_size
        )
        if is_strictly_better or is_tie_with_more_data:
            self.current_best_loss = validation_loss
            self.current_sample_size = sample_size
            self.current_best_loss_record_id = self.current_record_id
        self.current_record_id += 1
        record.dump(self.file)
        self.file.write("\n")
        # Flush after every record so the line is handed to the OS right away.
        self.file.flush()

    def checkpoint(self):
        """Write a checkpoint line naming the current best record.

        If no record has been marked as best yet, a warning is logged and
        nothing is written.

        Raises:
            IOError: If the file has not been opened.
        """
        if self.file is None:
            raise IOError("Call open() to open the output file first.")
        if self.current_best_loss_record_id is None:
            logger.warning("flaml.training_log: checkpoint() called before any record is written, skipped.")
            return
        record = TrainingLogCheckPoint(self.current_best_loss_record_id)
        record.dump(self.file)
        self.file.write("\n")
        self.file.flush()

    def close(self):
        """Close the file if it is open; safe to call repeatedly."""
        if self.file is not None:
            self.file.close()
            self.file = None  # for pickle
class TrainingLogReader(object):
def __init__(self, filename: str):
self.filename = filename
self.file = None
def open(self):
self.file = open(self.filename)
def records(self):
if self.file is None:
raise IOError("Call open() before reading log file.")
for line in self.file:
data = json.loads(line)
if len(data) == 1:
# Skip checkpoints.
continue
yield TrainingLogRecord(**data)
def close(self):
if self.file is not None:
self.file.close()
self.file = None # for pickle
def get_record(self, record_id) -> TrainingLogRecord:
if self.file is None:
raise IOError("Call open() before reading log file.")
for rec in self.records():
if rec.record_id == record_id:
return rec
raise ValueError(f"Cannot find record with id {record_id}.")
@contextmanager
def training_log_writer(filename: str, append: bool = False):
    """Context manager yielding an opened `TrainingLogWriter`.

    Args:
        filename: Path of the log file to write.
        append: If true, open the file with `append_open()`; otherwise open
            it with `open()`.

    Yields:
        TrainingLogWriter: The opened writer; it is closed on exit, including
        when the managed block or the open call raises.
    """
    # Construct outside the try so the finally clause never refers to an
    # unbound name if construction itself fails.
    writer = TrainingLogWriter(filename)
    try:
        if append:
            writer.append_open()
        else:
            writer.open()
        yield writer
    finally:
        writer.close()
@contextmanager
def training_log_reader(filename: str):
    """Context manager yielding an opened `TrainingLogReader`.

    Args:
        filename: Path of the log file to read.

    Yields:
        TrainingLogReader: The opened reader; it is closed on exit, including
        when the managed block or the open call raises.
    """
    # Construct outside the try so the finally clause never refers to an
    # unbound name if construction itself fails.
    reader = TrainingLogReader(filename)
    try:
        reader.open()
        yield reader
    finally:
        reader.close()