본문 바로가기
AI Framework/PyTorch

[PyTorch] 10주차: 모델 배포 및 최종 프로젝트 발표

by cogito21_python 2024. 5. 30.
반응형

강의 목표

  • 모델 배포의 중요성 및 방법 이해
  • PyTorch 모델을 배포하기 위한 다양한 기술 학습
  • 최종 프로젝트 발표 및 코드 리뷰를 통한 실전 감각 향상

강의 내용

1. 모델 배포의 중요성

  • 모델 배포의 개념
    • 학습된 모델을 실제 환경에 적용하여 예측 서비스 제공
    • 배포된 모델은 웹 애플리케이션, 모바일 앱, IoT 디바이스 등에서 사용할 수 있음
  • 배포의 주요 고려 사항
    • 성능 최적화: 예측 속도, 메모리 사용량 등
    • 안정성 및 확장성: 다양한 요청 처리 능력
    • 보안: 데이터 보호 및 접근 제어

2. 모델 배포 방법

  • PyTorch 모델 저장 및 로드
    • 모델 가중치 저장 및 로드
    • 전체 모델 저장 및 로드
     
# 모델 가중치 저장
torch.save(model.state_dict(), 'model_weights.pth')

# 모델 가중치 로드
model = SimpleCNN()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

# 전체 모델 저장
torch.save(model, 'model.pth')

# 전체 모델 로드
model = torch.load('model.pth')
model.eval()
  • ONNX(Open Neural Network Exchange)로 모델 변환
    • PyTorch 모델을 ONNX 형식으로 변환하여 다양한 플랫폼에서 사용
     
import torch.onnx

# 모델을 ONNX 형식으로 변환
dummy_input = torch.randn(1, 1, 28, 28)  # 모델 입력에 맞는 더미 데이터 생성
torch.onnx.export(model, dummy_input, 'model.onnx')

 

3. 배포 플랫폼 소개

  • Flask를 이용한 간단한 웹 서비스 구축
    • Flask는 Python으로 작성된 마이크로 웹 프레임워크로, 간단한 API 서버 구축에 유용
     
from flask import Flask, request, jsonify
import torch
from torchvision import transforms
from PIL import Image
import io

app = Flask(__name__)

# 모델 로드
model = torch.load('model.pth')
model.eval()

# 이미지 전처리 함수
def transform_image(image_bytes):
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    image = Image.open(io.BytesIO(image_bytes))
    return transform(image).unsqueeze(0)

# 예측 함수
def get_prediction(image_bytes):
    tensor = transform_image(image_bytes)
    outputs = model(tensor)
    _, predicted = torch.max(outputs.data, 1)
    return predicted.item()

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        prediction = get_prediction(img_bytes)
        return jsonify({'prediction': prediction})

if __name__ == '__main__':
    app.run()

 

  • 클라우드 서비스 배포
    • AWS, Google Cloud, Azure 등의 클라우드 플랫폼을 사용하여 모델 배포
    • 주요 서비스: AWS SageMaker, Google AI Platform, Azure Machine Learning

4. 최종 프로젝트 발표

  • 프로젝트 발표
    • 각 그룹 또는 개인은 프로젝트 결과 발표
    • 모델 설계, 훈련 과정, 성능 평가 결과 등을 공유
  • 코드 리뷰 및 피드백
    • 각 그룹의 코드를 리뷰하고 피드백 제공
    • 개선할 점 및 좋은 점 공유

5. 실습 및 과제

  • 실습 과제
    • 학습된 모델을 저장하고, Flask를 이용하여 간단한 웹 서비스 구축
    • 모델 배포 후 테스트 데이터로 예측 수행
     
# 실습 과제 예시
import torch
import torchvision.transforms as transforms
from PIL import Image
import io

# 모델 로드
model = torch.load('model.pth')
model.eval()

# 이미지 전처리 함수
def transform_image(image_bytes):
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    image = Image.open(io.BytesIO(image_bytes))
    return transform(image).unsqueeze(0)

# 예측 함수
def get_prediction(image_bytes):
    tensor = transform_image(image_bytes)
    outputs = model(tensor)
    _, predicted = torch.max(outputs.data, 1)
    return predicted.item()

# 테스트 이미지 예측
with open('test_image.png', 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes))

 

  • 과제 제출
    • Jupyter Notebook 파일 및 Flask 코드 제출
    • 제출 기한: 다음 강의 시작 전까지

 

반응형