반응형
Runner
- MMEngine안의 integrator로 모든 module 및 framework를 책임짐.
Simple Example
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from mmengine.model import BaseModel
from mmengine.evaluator import BaseMetric
from mmengine.registry import MODELS, DATASETS, METRICS
@MODELS.register_module()
class MyAwesomeModel(BaseModel):
    """Small MLP classifier used to demonstrate the Runner.

    The network is ``Linear(2, 64)`` + activation, followed by
    ``layers - 1`` further ``Linear(64, 64)`` + activation pairs and a
    final ``Linear(64, 2)`` classifier head.

    Args:
        layers: Number of hidden linear layers. The first hidden layer is
            always built, so values below 1 behave like 1.
        activation: Name of the activation: ``'relu'``, ``'silu'`` or
            ``'none'`` (identity).

    Raises:
        NotImplementedError: If ``activation`` is not a supported name.
    """

    def __init__(self, layers=4, activation='relu') -> None:
        super().__init__()
        if activation == 'relu':
            act_type = nn.ReLU
        elif activation == 'silu':
            act_type = nn.SiLU
        elif activation == 'none':
            act_type = nn.Identity
        else:
            # Same exception type as before, now naming the offending value.
            raise NotImplementedError(
                f"Unsupported activation {activation!r}; "
                "expected 'relu', 'silu' or 'none'.")

        sequence = [nn.Linear(2, 64), act_type()]
        for _ in range(layers - 1):
            sequence.extend([nn.Linear(64, 64), act_type()])
        self.mlp = nn.Sequential(*sequence)
        self.classifier = nn.Linear(64, 2)

    def forward(self, data, labels, mode):
        """Run the network in one of the three supported modes.

        Args:
            data: Input batch fed to the MLP; the first linear layer
                expects 2 features per sample.
            labels: Ground-truth class indices. Used to compute the loss in
                ``'loss'`` mode and returned unchanged in ``'predict'`` mode.
            mode: ``'tensor'``, ``'predict'`` or ``'loss'``.

        Returns:
            * ``'tensor'``: the raw classifier outputs.
            * ``'predict'``: a tuple ``(softmax over dim 1, labels)``.
            * ``'loss'``: a dict ``{'loss': cross-entropy loss}``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        x = self.mlp(data)
        x = self.classifier(x)
        if mode == 'tensor':
            return x
        if mode == 'predict':
            return F.softmax(x, dim=1), labels
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        # Previously an unknown mode fell through and returned None silently.
        raise ValueError(
            f"Unsupported mode {mode!r}; "
            "expected 'tensor', 'predict' or 'loss'.")
@DATASETS.register_module()
class MyDataset(Dataset):
    """Synthetic two-class dataset of noisy 2-D points.

    Each sample gets a random label in {0, 1}. Its point is placed at a
    radius of ``3 * (label + 1)`` plus standard-normal noise, at a uniformly
    random angle.

    Args:
        is_train: Selects the random seed: 0 when truthy, 3407 otherwise.
        size: Number of samples to generate.
    """

    def __init__(self, is_train, size):
        self.is_train = is_train
        # Seed the global torch RNG so each split is reproducible.
        torch.manual_seed(0 if self.is_train else 3407)
        self.labels = torch.randint(0, 2, (size,))

        # Polar coordinates: the radius depends on the class label.
        radius = 3 * (self.labels + 1) + torch.randn(self.labels.shape)
        angle = torch.rand(self.labels.shape) * 2 * torch.pi

        # Convert to Cartesian; stacking gives (2, size), transposed to
        # (size, 2).
        xs = radius * torch.cos(angle)
        ys = radius * torch.sin(angle)
        self.data = torch.vstack([xs, ys]).T

    def __getitem__(self, index):
        """Return the ``(point, label)`` pair at ``index``."""
        point = self.data[index]
        label = self.labels[index]
        return point, label

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
@METRICS.register_module()
class Accuracy(BaseMetric):
    """Classification accuracy in percent, accumulated batch by batch."""

    def __init__(self):
        super().__init__()

    def process(self, data_batch, data_samples):
        """Record the batch size and number of correct predictions.

        Args:
            data_batch: The batch from the dataloader (unused here).
            data_samples: A ``(scores, labels)`` pair; the predicted class is
                the argmax of ``scores`` along dim 1.
        """
        scores, labels = data_samples
        predicted = scores.argmax(dim=1)
        num_correct = (predicted == labels).sum().cpu()
        self.results.append({
            'batch_size': len(labels),
            'correct': num_correct,
        })

    def compute_metrics(self, results):
        """Combine the per-batch records into one accuracy value.

        Args:
            results: Records produced by :meth:`process`.

        Returns:
            A dict with key ``'accuracy'`` holding
            ``100 * correct / total``.
        """
        total_correct = sum(record['correct'] for record in results)
        total_size = sum(record['batch_size'] for record in results)
        return {'accuracy': 100 * total_correct / total_size}
from torch.utils.data import DataLoader, default_collate
from torch.optim import Adam
from mmengine.runner import Runner
runner = Runner(
    # The model to train.
    model=MyAwesomeModel(layers=2, activation='relu'),
    # Directory where checkpoints and logs are written.
    work_dir='exp/my_awesome_model',
    # Training data.
    train_dataloader=DataLoader(
        dataset=MyDataset(is_train=True, size=10000),
        shuffle=True,
        collate_fn=default_collate,
        batch_size=64,
        pin_memory=True,
        num_workers=2,
    ),
    # Training loop settings.
    train_cfg={
        'by_epoch': True,   # count progress in epochs rather than iterations
        'max_epochs': 10,
        'val_begin': 2,     # start validating from the 2nd epoch
        'val_interval': 1,  # validate every epoch
    },
    # OptimizerWrapper is an MMEngine concept offering richer optimization
    # options. The default works for most cases; see the MMEngine docs for
    # alternatives such as 'AmpOptimWrapper' (mixed-precision training).
    optim_wrapper={
        'optimizer': {'type': Adam, 'lr': 0.001},
    },
    # Parameter scheduler: adjusts learning rate / momentum during training.
    param_scheduler={
        'type': 'MultiStepLR',
        'by_epoch': True,
        'milestones': [4, 8],
        'gamma': 0.1,
    },
    # Validation data.
    val_dataloader=DataLoader(
        dataset=MyDataset(is_train=False, size=1000),
        shuffle=False,
        collate_fn=default_collate,
        batch_size=1000,
        pin_memory=True,
        num_workers=2,
    ),
    # Validation loop settings; usually left as an empty dict.
    val_cfg={},
    # Metric(s) used for evaluation.
    val_evaluator={'type': Accuracy},
    # Everything below is advanced configuration; leave at the defaults
    # unless needed.
    default_hooks={
        # The hook most commonly tuned: checkpoint saving interval.
        'checkpoint': {'type': 'CheckpointHook', 'interval': 1},
    },
    # `launcher` and `env_cfg` describe the distributed environment.
    launcher='none',
    env_cfg={
        'cudnn_benchmark': False,               # whether to enable cudnn_benchmark
        'backend': 'nccl',                      # distributed communication backend
        'mp_cfg': {'mp_start_method': 'fork'},  # multiprocessing settings
    },
    log_level='INFO',
    # Path to load model weights from; None means no loading.
    load_from=None,
    # Whether to resume training from a checkpoint.
    resume=False,
)

# Start training the model.
runner.train()
- 인자는 dict 형태로 전달됨. MMEngine's style of runner construction (manual construction / construction via registry)
from mmengine.model import BaseModel
from mmengine.runner import Runner
from mmengine.registry import MODELS
@MODELS.register_module()
class MyModel(BaseModel):
    """Placeholder model registered in ``MODELS`` under the name ``MyModel``.

    Args:
        arg1: Placeholder constructor argument.
        arg2: Placeholder constructor argument.
    """

    def __init__(self, arg1, arg2):
        # Initialise the base class, as the full example model does.
        super().__init__()
# Construction via registry: pass a config dict whose `type` is the
# registered class name; the remaining keys are the constructor arguments.
# (`xxx` stands for a value you supply.)
runner = Runner(
    model=dict(
        type="MyModel",
        arg1=xxx,
        arg2=xxx),
    # ... remaining Runner arguments
)

# Manual construction: build the model instance yourself and pass it in.
model = MyModel(arg1=xxx, arg2=xxx)
runner = Runner(
    model=model,
    # ... remaining Runner arguments
)
- Config 파일이 주어졌을 때
from mmengine.config import Config
from mmengine.runner import Runner
# Load the configuration from a Python config file.
config = Config.fromfile('example_config.py')
# Build the runner from the loaded config instead of passing arguments
# one by one.
runner = Runner.from_cfg(config)
# Start training.
runner.train()
Basic Dataflow
- 회색으로 채워진 도형은 서로 다른 데이터 형식, 실선 상자는 module과 method를 나타냄
- 상속과 override로 인해 diagram이 항상 유지되는 것은 아님.
- train시 붉은 실선 경로를 따라 진행. valid나 test시 파란 실선 경로를 따라 진행. debugging시 초록 실선 경로를 따라 진행
- DataLoader, Model과 Evaluator사이의 data format convention
- data_preprocessor를 통과하여 model로 갈때 unpacking이 이루어짐
- data_samples는 model의 outputs를 받지만 data_batch는 data_loader에서 나온 raw data를 받음.
# training
for data_batch in train_dataloader:
    data_batch = data_preprocessor(data_batch)
    # The preprocessed batch is unpacked into the model's forward arguments:
    # a dict as keyword arguments, a list/tuple as positional arguments.
    if isinstance(data_batch, dict):
        losses = model.forward(**data_batch, mode='loss')
    elif isinstance(data_batch, (list, tuple)):
        losses = model.forward(*data_batch, mode='loss')
    else:
        raise TypeError()

# validation
for data_batch in val_dataloader:
    data_batch = data_preprocessor(data_batch)
    if isinstance(data_batch, dict):
        outputs = model.forward(**data_batch, mode='predict')
    elif isinstance(data_batch, (list, tuple)):
        # A list/tuple must be unpacked positionally with `*`; `**` needs a
        # mapping and would raise TypeError here.
        outputs = model.forward(*data_batch, mode='predict')
    else:
        raise TypeError()
    # The evaluator receives the model outputs as `data_samples` and the
    # batch as `data_batch`.
    evaluator.process(data_samples=outputs, data_batch=data_batch)
# Compute the final metrics once every validation batch has been processed.
metrics = evaluator.evaluate(len(val_dataloader.dataset))
- data 전처리는 data transforms module에서 처리를 권장하지만 batch와 관련된 처리는 data_preprocessor에서 구현 가능.
Reference
- [MMEngine: Runner](https://mmengine.readthedocs.io/en/latest/tutorials/runner.html#)
반응형
'OpenMMLab(미공개) > MMEngine' 카테고리의 다른 글
[MMEngine] Tutorial - Dataset과 DataLoader (0) | 2024.01.14 |
---|---|
[MMEngine] 빠른 Tutorial (0) | 2024.01.14 |
[MMEngine] 환경 설정 (0) | 2024.01.14 |
[MMEngine] 개요 (0) | 2024.01.14 |