본문 바로가기
OpenMMLab(미공개)/MMEngine

[MMEngine] 빠른 Tutorial

by cogito21_python 2024. 1. 14.
반응형

Model 빌드

- MMEngine내에서 모델은 BaseModel을 상속 받음

- 입력값으로 forward 메서드는 mode라는 인자를 받아야함.

  - training에서 mode는 loss, forward 메서드는 loss키를 포함한 dict를 반환

  - validation에서 mode는 predict, forward 메서드는 예측값과 label들을 반환

import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels

 

Dataset과 DataLoader 빌드

- 학습과 검증에 있어서 Dataset과 DataLoader 생성.

- torchvision에 있는 datasets 사용 가능

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))

 

Evaluation Metrics 빌드

- metric은 BaseMetric을 상속 받음.

- process와 compute_metrics 메서드 구현

  - mode가 "predict"라면 process 메서드는 dataset의 output과 다른 output을 받아들임.

  - output data는 데이터의 배치임. 이 과정 이후 데이터의 배치는 self.results 속성에 정보가 기록됨.

  - compute_metrics는 results 파라미터를 받음. compute_metrics에 들어오는 result는 process 메서드 내 정보를 모두 저장함.

  - 평가 metrics 결과를 dict로 반환

from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # save the middle result of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # return the dict containing the eval results
        # the key is the name of the metric name
        return dict(accuracy=100 * total_correct / total_size)

 

Runner 빌드와 Task 수행

-  Model, DataLoader와 Metrics를 이용해 Runner를 빌드

from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    # 학습 및 평가를 할 모델
    model=MMResNet50(),
    # 학습 로그와 가중치를 저장할 폴더
    work_dir='./work_dir',
    # Pytorch 데이터 로더
    train_dataloader=train_dataloader,
    # optimize wrapper
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # 학습에 필요한 매개 변수
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    # PyTorch 데이터 로더
    val_dataloader=val_dataloader,
    # 평가에 필요한 추가적인 파라미터를 담은 validation configs
    val_cfg=dict(),
    # validation evaluator
    val_evaluator=dict(type=Accuracy),
)

runner.train()

 

 

반응형

'OpenMMLab(미공개) > MMEngine' 카테고리의 다른 글

[MMEngine] Tutorial - Dataset과 DataLoader  (0) 2024.01.14
[MMEngine] Tutorial - Runner  (0) 2024.01.14
[MMEngine] 환경 설정  (0) 2024.01.14
[MMEngine] 개요  (0) 2024.01.14