Object Separation (cloth-segmentation)
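1. Table of Contents
- Environment setup
- Model: cloth-segmentation
- Code changes
- Modify the existing infer.py to process the input data
- Inference results
- Result comparison
- Final results
2. Environment Setup
- AI model test environment
- Ubuntu 22.04 (workstation)
- Anaconda
- VS Code
- Python 3.8.18
- Setup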
git clone https://github.com/levindabhi/cloth-segmentation.git
cd cloth-segmentation
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
conda install -c conda-forge tensorboardx
pip install gdown
python setup_model_weights.py
python train.py
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=4 --use_env train.py
python infer.py
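Before running infer.py, it can help to confirm that the packages installed above are importable in the active conda environment. The following is a minimal sketch using only the standard library. The module list is an assumption based on the install commands above and the imports in infer.py, and `check_modules` is a hypothetical helper, not part of the repository.

```python
import importlib.util


def check_modules(names):
    """Return {module_name: True/False} depending on whether each module can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed import names for the packages installed above.
    required = ["torch", "torchvision", "tensorboardX", "gdown", "PIL", "tqdm"]
    for name, found in check_modules(required).items():
        print(f"{name:15s} {'OK' if found else 'MISSING'}")
```

Any module reported as MISSING would need to be installed into the same environment before the scripts above can run.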
3. cloth-segmentation
https://github.com/levindabhi/cloth-segmentation
1. Model (cloth-segmentation)
U2-Net Architecture
4. Code Changes
infer.py
import os
import warnings

# Silence noisy library warnings before importing torch.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm

# Modules from the cloth-segmentation repository.
from data.base_dataset import Normalize_image
from networks import U2NET
from utils.saving_utils import load_checkpoint_mgpu

device = "cuda"

# Input frames, output masks, and the pretrained checkpoint.
image_dir = "/home/viking3/AI/Object/cloth_segmentation/input_images/sample/"
result_dir = "/home/viking3/AI/Object/cloth_segmentation/output_images/"
checkpoint_path = "/home/viking3/AI/Object/cloth_segmentation/trained_checkpoint/cloth_segm.pth"

if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"No checkpoint found at {checkpoint_path}")
print(f"Checkpoint path: {checkpoint_path}")

# Make sure the output directory exists before saving.
os.makedirs(result_dir, exist_ok=True)

do_palette = True


def get_palette(num_cls):
    """Return a flat [R, G, B, ...] palette with one color per class index."""
    n = num_cls
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            # Spread the bits of the class index over the R, G, B channels.
            palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
            palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
            palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
            i += 1
            lab >>= 3
    return palette


transforms_list = [
    transforms.ToTensor(),
    Normalize_image(0.5, 0.5),
]
transform_rgb = transforms.Compose(transforms_list)

# 3 input channels (RGB), 4 output classes.
net = U2NET(in_ch=3, out_ch=4)
net = load_checkpoint_mgpu(net, checkpoint_path)
if net is None:
    raise RuntimeError(f"Failed to load the checkpoint from {checkpoint_path}")
net = net.to(device)
net = net.eval()

palette = get_palette(4)

images_list = sorted(os.listdir(image_dir))
pbar = tqdm(total=len(images_list))
for image_name in images_list:
    img = Image.open(os.path.join(image_dir, image_name)).convert("RGB")
    image_tensor = transform_rgb(img)
    image_tensor = torch.unsqueeze(image_tensor, 0)  # add batch dimension

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        output_tensor = net(image_tensor.to(device))
        # Use the first network output as the prediction.
        output_tensor = F.log_softmax(output_tensor[0], dim=1)
        # Per-pixel class index (0-3).
        output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_arr = output_tensor.cpu().numpy()

    output_img = Image.fromarray(output_arr.astype("uint8"), mode="L")
    if do_palette:
        output_img.putpalette(palette)
    # Save as PNG under the same base name (works for .jpg and .jpeg alike).
    base_name = os.path.splitext(image_name)[0]
    output_img.save(os.path.join(result_dir, base_name + ".png"))
    pbar.update(1)
pbar.close()
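The post-processing in the loop applies `log_softmax` across the class dimension and then takes the index of the maximum value for each pixel. Because log-softmax only subtracts the same constant from every class score of a pixel, the winning class is the one the raw scores would give as well. Below is a minimal pure-Python sketch of that per-pixel step, using made-up scores for the 4 output classes. The helper names are hypothetical and this is not the PyTorch implementation.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over one pixel's class scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def argmax(values):
    """Index of the largest value (first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


# Made-up raw scores for one pixel: class indices 0, 1, 2, 3.
pixel_scores = [0.2, 2.5, -1.0, 0.7]

label_from_raw = argmax(pixel_scores)
label_from_log_softmax = argmax(log_softmax(pixel_scores))
print(label_from_raw, label_from_log_softmax)  # prints: 1 1
```

So the `log_softmax` call matters if the normalized scores themselves are needed, but not for choosing the label.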
This is a U2NET object-segmentation model that has been trained specifically on clothing.
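The loop in infer.py processes every entry returned by `os.listdir(image_dir)`, so a stray text file or subfolder in the input directory would make `Image.open` fail. The following is a hedged sketch of a stricter listing step that keeps only image files and derives the `.png` output name from the input name. The extension set and both helper names are assumptions, not part of the repository.

```python
from pathlib import Path

# Assumed set of input extensions; adjust to the frames you actually extract.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def list_images(image_dir):
    """Return sorted image file names in image_dir, skipping other files and folders."""
    return sorted(
        p.name
        for p in Path(image_dir).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


def output_name(image_name):
    """Map an input file name to the .png name used for the saved mask."""
    return Path(image_name).with_suffix(".png").name
```

With these helpers, `images_list = list_images(image_dir)` could stand in for the `sorted(os.listdir(image_dir))` call, and `output_name(image_name)` gives the file name to save under.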
5. Inference Results
Input video frame (cloth-segmentation)
Output video frame (cloth-segmentation)
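One simple way to compare output frames numerically is to measure how much of each frame is assigned to each class. The sketch below assumes the mask is available as a flat sequence of class indices 0 to 3, which is what Pillow's `Image.getdata()` is expected to yield for the palettized PNG saved by infer.py; that link to the saved files is an assumption, and `class_coverage` is a hypothetical helper.

```python
from collections import Counter


def class_coverage(labels, num_classes=4):
    """Return the fraction of pixels assigned to each class index."""
    labels = list(labels)
    if not labels:
        raise ValueError("empty mask")
    counts = Counter(labels)
    total = len(labels)
    return [counts.get(c, 0) / total for c in range(num_classes)]


# Tiny made-up 2x4 mask, flattened row by row.
mask = [0, 0, 1, 1,
        0, 2, 2, 1]
print(class_coverage(mask))  # [0.375, 0.375, 0.25, 0.0]
```

A frame whose non-background coverage drops sharply compared with its neighbors is a reasonable candidate for a closer visual check.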
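6. Final Results
This model recognizes and separates clothing objects, not human body motion.
Full-body garments are marked in yellow, tops in red, and bottoms in green. Targeting these segmented regions, we will probably need to test two approaches before deciding: replacing them with trained clothing data, or generating the fill content with a GAN.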
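The colors mentioned above follow from the palette that infer.py builds. The self-contained sketch below reproduces the bit manipulation from `get_palette` and prints the RGB color for each of the 4 class indices. The garment names attached to the indices are inferred from the color description above and should be treated as an assumption about the class order.

```python
def get_palette(num_cls):
    """Build a flat [R, G, B, R, G, B, ...] palette, one color per class index."""
    palette = [0] * (num_cls * 3)
    for j in range(num_cls):
        lab, i = j, 0
        while lab:
            palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
            palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
            palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
            i += 1
            lab >>= 3
    return palette


# Assumed class order, inferred from the colors described in the text.
CLASS_NAMES = ["background", "upper body", "lower body", "full body"]

palette = get_palette(len(CLASS_NAMES))
colors = [tuple(palette[k * 3:k * 3 + 3]) for k in range(len(CLASS_NAMES))]
for name, rgb in zip(CLASS_NAMES, colors):
    print(f"{name:12s} {rgb}")
```

Index 1 comes out as (128, 0, 0), index 2 as (0, 128, 0), and index 3 as (128, 128, 0), which correspond to the red, green, and yellow regions in the output frames.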
!! The checkpoint can no longer be downloaded, so keep your copy in a safe place.