import os

try:
    from transformers import Seq2SeqTrainer
except ImportError:
    # Fall back to `object` so this module can still be imported when
    # the transformers package is unavailable.
    Seq2SeqTrainer = object


class TrainerForAuto(Seq2SeqTrainer):
    """Trainer that dispatches between Seq2SeqTrainer and the base Trainer.

    Each override checks the `_is_seq2seq` attribute: when set, calls go
    through `Seq2SeqTrainer` (which supports generation arguments);
    otherwise they bypass it and use `transformers.Trainer` directly.
    The class also records per-epoch logs and per-checkpoint metrics.
    """

    def predict(
        self,
        test_dataset,
        ignore_keys=None,
        metric_key_prefix=None,
        max_length=None,
        num_beams=None,
    ):
        if getattr(self, "_is_seq2seq", None):
            return super().predict(
                test_dataset,
                ignore_keys,
                metric_key_prefix=metric_key_prefix,
                max_length=max_length,
                num_beams=num_beams,
            )
        else:
            # Skip Seq2SeqTrainer in the MRO: Trainer.predict does not
            # accept generation arguments such as max_length or num_beams.
            return super(Seq2SeqTrainer, self).predict(test_dataset, ignore_keys, metric_key_prefix)
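
    # Hypothetical call (all names below are placeholders, not from this
    # module):
    #     output = trainer.predict(test_data, metric_key_prefix="test")
    #     output.predictions, output.metrics  # PredictionOutput fields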

    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys,
    ):
        if getattr(self, "_is_seq2seq", None):
            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
        else:
            return super(Seq2SeqTrainer, self).prediction_step(model, inputs, prediction_loss_only, ignore_keys)

    def log(self, logs) -> None:
        if getattr(self, "_is_seq2seq", None):
            super().log(logs)
        else:
            super(Seq2SeqTrainer, self).log(logs)
        if not hasattr(self, "intermediate_results"):
            self.intermediate_results = {}

        # Group logged values by epoch so they can be inspected later.
        epoch_num = logs.get("epoch", None)
        if epoch_num:
            self.intermediate_results.setdefault(epoch_num, {})
            self.intermediate_results[epoch_num].update(logs)
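
    # Shape of the recorded results after two logged epochs
    # (hypothetical values, for illustration only):
    #     self.intermediate_results == {
    #         1.0: {"loss": 0.52, "epoch": 1.0},
    #         2.0: {"eval_loss": 0.40, "epoch": 2.0},
    #     }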

    def evaluate(
        self,
        eval_dataset=None,
        ignore_keys=None,
        metric_key_prefix="eval",
    ):
        """Overriding transformers.Trainer.evaluate by saving metrics and checkpoint path."""
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        ckpt_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        # TODO: if your task is seq2seq (i.e., SUMMARIZATION), uncomment the code below (add indentation before metrics = eval_dataset...
        if getattr(self, "_is_seq2seq", None):
            metrics = eval_dataset and super().evaluate(
                eval_dataset,
                ignore_keys,
                metric_key_prefix,
                max_length=self.args.generation_max_length,
                num_beams=self.args.generation_num_beams,
            )
        else:
            metrics = eval_dataset and super(Seq2SeqTrainer, self).evaluate(
                eval_dataset,
                ignore_keys,
                metric_key_prefix,
            )

        # Record the mapping from checkpoint path to global step and metrics,
        # creating the bookkeeping dicts on the first call.
        if hasattr(self, "ckpt_to_global_step"):
            self.ckpt_to_global_step[ckpt_dir] = self.state.global_step
            if metrics:
                self.ckpt_to_metric[ckpt_dir] = metrics
        else:
            self.ckpt_to_global_step = {ckpt_dir: self.state.global_step}
            self.ckpt_to_metric = {ckpt_dir: metrics} if metrics else {}

        return metrics
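

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# All names below -- model, training_args, train_data, eval_data -- are
# hypothetical placeholders, and training_args is assumed to be a
# Seq2SeqTrainingArguments instance when _is_seq2seq is set:
#
#     trainer = TrainerForAuto(
#         model=model,
#         args=training_args,
#         train_dataset=train_data,
#         eval_dataset=eval_data,
#     )
#     trainer._is_seq2seq = True   # route calls through Seq2SeqTrainer
#     trainer.train()              # evaluate() fills ckpt_to_metric as it runs
#     print(trainer.ckpt_to_metric)  # {checkpoint_dir: metrics}
# ---------------------------------------------------------------------------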