Object Separation (cloth-segmentation)

menstua 2024. 5. 22. 12:56

1. Table of Contents

  • Environment setup
  • Model: cloth-segmentation
  • Code changes
    • Modify the existing infer.py to process our input data
  • Inference results
    • Comparison of results
  • Final results

2. Environment Setup

  • AI model test environment
    • Ubuntu 22.04 (workstation)
    • Anaconda
    • VS Code
    • Python 3.8.18
  • Setup
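# Clone the repository
git clone https://github.com/levindabhi/cloth-segmentation.git
cd cloth-segmentation

# Install the dependencies
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
conda install -c conda-forge tensorboardx
pip install gdown

# Set up the model weights (see the repository README)
python setup_model_weights.py

# Train on a single GPU
python train.py

# Or train on multiple GPUs (4 processes on one node)
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=4 --use_env train.py

# Run inference
python infer.py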
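
Before running the scripts, it can help to confirm that the packages installed above are importable. The following is a minimal sketch, not part of the repository: the helper name missing_modules and the module list are assumptions based on the install commands above and the imports used in infer.py.

```python
from importlib.util import find_spec

# Assumed list: packages from the install commands plus the imports in infer.py
REQUIRED_MODULES = ("torch", "torchvision", "tensorboardX", "gdown", "tqdm", "PIL", "numpy")


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


missing = missing_modules()
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules are importable.")
```

Run it inside the activated conda environment. It only checks that each module can be located, not that the installed versions are compatible with each other.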

3. cloth-segmentation

https://github.com/levindabhi/cloth-segmentation

1. Model (cloth-segmentation)

U2-Net Architecture

4. Code Changes

infer.py

import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from data.base_dataset import Normalize_image
from utils.saving_utils import load_checkpoint_mgpu
from networks import U2NET

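# Use the GPU when available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"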

image_dir = "/home/viking3/AI/Object/cloth_segmentation/input_images/sample/"
result_dir = "/home/viking3/AI/Object/cloth_segmentation/output_images/"
checkpoint_path = '/home/viking3/AI/Object/cloth_segmentation/trained_checkpoint/cloth_segm.pth'

if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"No checkpoint found at {checkpoint_path}")
print(f'Checkpoint path: {checkpoint_path}')

# Create the output directory if it does not exist yet
os.makedirs(result_dir, exist_ok=True)

do_palette = True

def get_palette(num_cls):
    """Return a flat [R, G, B, R, G, B, ...] palette with one color per class index."""
    n = num_cls
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
            palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
            palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
            i += 1
            lab >>= 3
    return palette

transforms_list = [
    transforms.ToTensor(),
    Normalize_image(0.5, 0.5)
]
transform_rgb = transforms.Compose(transforms_list)

net = U2NET(in_ch=3, out_ch=4)
net = load_checkpoint_mgpu(net, checkpoint_path)
if net is None:
    raise RuntimeError(f"Failed to load the checkpoint from {checkpoint_path}")
net = net.to(device)
net = net.eval()

palette = get_palette(4)

images_list = sorted(os.listdir(image_dir))
pbar = tqdm(total=len(images_list))
for image_name in images_list:
    img = Image.open(os.path.join(image_dir, image_name)).convert("RGB")
    image_tensor = transform_rgb(img)
    image_tensor = torch.unsqueeze(image_tensor, 0)  # add the batch dimension

    # Inference only: disable gradient tracking to save memory
    with torch.no_grad():
        output_tensor = net(image_tensor.to(device))
    # Per-pixel class index: argmax over the class (channel) dimension
    output_tensor = F.log_softmax(output_tensor[0], dim=1)
    output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_arr = output_tensor.cpu().numpy()

    output_img = Image.fromarray(output_arr.astype("uint8"), mode="L")
    if do_palette:
        output_img.putpalette(palette)
    # Swap the extension for .png whatever its length (.jpg, .jpeg, ...)
    output_name = os.path.splitext(image_name)[0] + ".png"
    output_img.save(os.path.join(result_dir, output_name))

    pbar.update(1)

pbar.close()

 

This is a U2NET object-segmentation model that has been trained on clothing to specialize it for segmenting clothes.
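
The post-processing inside the inference loop can be illustrated without PyTorch. The sketch below is an illustration under stated assumptions, not the repository's code: the helper names and the toy 1x2 "image" are hypothetical. It applies a log-softmax to each pixel's four class scores and then takes the index of the largest value, which mirrors what F.log_softmax followed by torch.max(...)[1] does in infer.py.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over one pixel's class scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def predict_class(scores):
    """Index of the class with the highest log-probability."""
    log_probs = log_softmax(scores)
    return max(range(len(log_probs)), key=log_probs.__getitem__)


def label_map(pixel_scores):
    """pixel_scores is an H x W nested list; each entry holds the per-class scores."""
    return [[predict_class(px) for px in row] for row in pixel_scores]


# Hypothetical 1x2 image with 4 class scores per pixel
toy = [[[2.0, 0.1, -1.0, 0.3], [0.0, 0.5, 3.0, 1.0]]]
print(label_map(toy))  # [[0, 2]]
```

Because log-softmax shifts all scores of a pixel by the same constant, the predicted class is the same as the argmax of the raw scores. The resulting label map holds one class index (0 to 3) per pixel, which is what gets saved as the output PNG.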

5. Inference Results

Input video frame (cloth-segmentation)

Output video frame (cloth-segmentation)

6. Final Results

This model recognizes and segments clothing objects, not human body motion.

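Full-body garments are marked in yellow, upper-body clothes in red, and lower-body clothes in green. With these regions as the target, we still need to test two approaches before deciding between them: replacing the region with trained clothing data, or filling it in with GAN-generated content.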
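
These colors follow from the palette that get_palette(4) builds in infer.py. The sketch below repeats that function and lists the RGB value assigned to each class index. The class names are assumptions inferred from the colors described above (index 0 is assumed to be the background), and class_colors is a hypothetical helper.

```python
def get_palette(num_cls):
    """Same bit-spreading scheme as get_palette in infer.py."""
    palette = [0] * (num_cls * 3)
    for j in range(num_cls):
        lab, i = j, 0
        while lab:
            palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
            palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
            palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
            i += 1
            lab >>= 3
    return palette


# Assumed names, inferred from the colors seen in the output frames
CLASS_NAMES = {0: "background", 1: "upper body", 2: "lower body", 3: "full body"}


def class_colors(num_cls=4):
    """Map each class index to its (R, G, B) palette entry."""
    flat = get_palette(num_cls)
    return {idx: tuple(flat[idx * 3: idx * 3 + 3]) for idx in range(num_cls)}


for idx, rgb in class_colors().items():
    print(idx, CLASS_NAMES[idx], rgb)
# 0 background (0, 0, 0)
# 1 upper body (128, 0, 0)
# 2 lower body (0, 128, 0)
# 3 full body (128, 128, 0)
```

Since the output PNG is saved in palette mode, each pixel stores the class index rather than the color, so a target region can be selected by comparing pixel values against 1, 2, or 3.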

!! The checkpoint can no longer be downloaded, so keep your copy safe.
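
One simple way to look after the file is to record its checksum and compare it against any backup copy. This is a minimal sketch using only the standard library. The helper names are hypothetical, and the relative checkpoint path in the usage lines is an assumption, so adjust it to wherever your copy lives.

```python
import hashlib
import os


def sha256_of_file(path, chunk_size=1 << 20):
    """Compute the SHA-256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def same_content(path_a, path_b):
    """Return True when both files have identical SHA-256 digests."""
    return sha256_of_file(path_a) == sha256_of_file(path_b)


# Assumed location of the checkpoint; nothing is printed if it is not there
checkpoint = "trained_checkpoint/cloth_segm.pth"
if os.path.exists(checkpoint):
    print(sha256_of_file(checkpoint))
```

Storing the printed digest next to the backup makes it easy to check later that the copy has not been corrupted.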
