"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch
import torch.nn as nn

from ..misc import MetricLogger, SmoothedValue, reduce_dict


def train_one_epoch(
    model: nn.Module,
    criterion: nn.Module,
    dataloader,
    optimizer,
    ema,
    epoch,
    device,
    *,
    print_freq: int = 100,
):
    """Train ``model`` for one pass over ``dataloader`` and return averaged metrics.

    For every ``(imgs, labels)`` batch yielded by ``dataloader`` this moves the
    batch to ``device``, runs the forward pass, computes
    ``criterion(preds, labels, epoch)``, and performs one optimizer step.

    Args:
        model: Network to train; switched to training mode via ``model.train()``.
        criterion: Loss module, called as ``criterion(preds, labels, epoch)``.
        dataloader: Iterable yielding ``(imgs, labels)`` pairs.
        optimizer: Optimizer stepped once per batch. The ``"lr"`` of its first
            param group is logged each iteration.
        ema: Optional model-EMA helper. When not ``None``, ``ema.update(model)``
            is called after every optimizer step.
        epoch: Epoch index, used in the log header and forwarded to ``criterion``.
        device: Target passed to ``.to()`` for images and labels.
        print_freq: Logging interval (in iterations) passed to
            ``MetricLogger.log_every``. Keyword-only; defaults to 100.

    Returns:
        dict mapping each meter name (at least ``"loss"`` and ``"lr"``) to its
        ``global_avg`` after ``synchronize_between_processes()``.
    """
    model.train()

    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)

    for imgs, labels in metric_logger.log_every(dataloader, print_freq, header):
        imgs = imgs.to(device)
        labels = labels.to(device)

        preds = model(imgs)
        loss: torch.Tensor = criterion(preds, labels, epoch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if ema is not None:
            ema.update(model)

        # reduce_dict is project-local; presumably it combines the loss across
        # distributed workers before logging — TODO confirm against ..misc.
        loss_reduced_values = {k: v.item() for k, v in reduce_dict({"loss": loss}).items()}
        metric_logger.update(**loss_reduced_values)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return stats


@torch.no_grad()
def evaluate(model, criterion, dataloader, device, *, print_freq: int = 10):
    """Evaluate ``model`` on ``dataloader`` and return averaged accuracy and loss.

    Runs under ``torch.no_grad()`` with the model in eval mode. For each
    ``(imgs, labels)`` batch it computes top-1 accuracy (argmax over the last
    dimension of the predictions compared to ``labels``) and
    ``criterion(preds, labels)``.

    Args:
        model: Network to evaluate; switched to eval mode via ``model.eval()``.
        criterion: Loss module, called as ``criterion(preds, labels)`` (no epoch).
        dataloader: Iterable yielding ``(imgs, labels)`` pairs.
        device: Target passed to ``.to()`` for images and labels.
        print_freq: Logging interval (in iterations) passed to
            ``MetricLogger.log_every``. Keyword-only; defaults to 10.

    Returns:
        dict mapping each meter name (``"acc"`` and ``"loss"``) to its
        ``global_avg`` after ``synchronize_between_processes()``.
    """
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter("acc", SmoothedValue(window_size=1))
    metric_logger.add_meter("loss", SmoothedValue(window_size=1))
    header = "Test:"

    for imgs, labels in metric_logger.log_every(dataloader, print_freq, header):
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs)

        # Fraction of correct top-1 predictions in this batch. Assumes `labels`
        # holds class indices shaped like `preds.argmax(dim=-1)` — TODO confirm.
        acc = (preds.argmax(dim=-1) == labels).sum() / preds.shape[0]
        loss = criterion(preds, labels)

        dict_reduced = reduce_dict({"acc": acc, "loss": loss})
        reduced_values = {k: v.item() for k, v in dict_reduced.items()}
        # NOTE(review): each batch contributes one sample to the meters, so the
        # final global_avg is a mean of per-batch means; a smaller final batch is
        # over-weighted relative to a true per-sample average.
        metric_logger.update(**reduced_values)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return stats