Jetson-nano,  programming

Jetson nanoで画像分類してみた

今回はサムズアップ・ダウン分類をしていきます.親指の上下分類っすね.

Jetson nano(ハード)の準備

今回は,カメラで撮影した画像を分類していくのでJetson nanoにカメラモジュールを載せてからをPCに接続し電源を入れる.

JupyterLabを起動

前回同様,ターミナルからJupyterLabを起動する↓

プログラムの実行

前回の記事でコンテナ作って下さった方はClassificationというフォルダがあるので,その中の classification_interactive.ipynb を開いて実行していきます.

基本的にはそのまま実行していけば,できるはずなのですが一応解説を.

カメラの起動

#接続デバイス(今回はカメラ)ナンバーの確認
!ls -ltrh /dev/video*

返って来るのはこんなのデバイス0を使ってますよ〜ってことっすね

crw-rw---- 1 root video 81, 0 Dec 19 20:38 /dev/video0

で実際に起動して行く〜

#USBでウェブカムを使用してる場合
from jetcam.usb_camera import USBCamera
camera = USBCamera(width=224, height=224, caputure_device=0)

#CSIでカメラ接続してる場合(ラズパイのモジュールなど)
from jetcam.csi_camera import CSICamera
camera = CSICamera(width=224, height=224, capture_device=0)

camera.running = true
print("camera created")

諸設定

データセット作り,分類のための学習を設定

import torchvision.transforms as transforms
from dataset import ImageClassificationDataset

TASK = 'thumbs'

CATEGORIES = ['thumbs_up', 'thumbs_down']

DATASETS = ['A',' B']

TRANSFORMS = tarnsforms.Compose([
           transforms.ColorJitter(0.2, 0.2, 0.2, 0.2 ),
           transforms.Resize((224, 224)),
           transforms.ToTensor(),
           transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225] )])

datasets = {}
for name in DATASETS:
       datasets[name] = ImageClassificationDataset('../data/classification/' + TASK + '_' + name, CATEGORIES, TRANSFORMS)

print("{} task with categories defined".format(TASK, CATEGORIES))

データ保存場所を作る

DATA_DIR = '/nvdli-nano/data/classification/'
!mkdir -p {DATA_DIR}

データ収集

import ipywidets
import traitlets
from Ipython.display
from jetcam.utils import bgr8_to_jpeg

#データセットの初期化
dataset = datasets[DATASETS[0]]

#このセルを2回実行したとき,カメラからのコールバックをオブザーブしないようにする
camera.unobserve_all()

#イメージプレビュー作り
camera_widget = ipywidgets.Image()
traitlets.dlink((camera, 'value'), (camera_widget,  'value'),transform=bgr8_to_jpeg)

#widgetを作る
dataset_wiget = ipywidgets.Dropdown(options=DATASETS, description='dataset')
category_widget = ipywidgets.Dropdown(options=dataset.categories, discription='category')
count_woget = ipywidgets.IntText(discription='count')
save_wiget = ipywidgets.Button(discription='add')

#初期化時にカウントを手動で更新する
count_wiget.value = dataset.get_count(category_woget.value)

#アクティブなデータセットを設定
def set_dataset(change):
       global dataset
       dataset = datasets[cahnge['new']]
       count_widget.value = dataset.get_count(category_widget.value)
dataset_widget.observe(set_dataset, name='value')

#new categoryを選択したときカウントをリセットする
def update_count(change):
       count_widget.value = dataset.get_count(change['new'])
category_widget.observe(update_counts, name='value')

#カテゴリーにイメージを保存しカウントをアップデートする
def save(c):
    dataset.save_entry(camera.value, category_widget.value)
    count_widget.value = dataset.get_count(category_widget.value)
save_widget.on_click(save)

data_collection_widget = ipywidgets.VBox([
    ipywidgets.HBox([camera_widget]), dataset_widget, category_widget, count_widget, save_widget
])

# data_collection_widgetを表示する
print("data_collection_widget created")

モデル作り

DeepLearningのフレームワークであるpytochを利用する.

今回は”RESNET 18”をモデルとして利用する

import torch
import torchvision


device = torch.device('cuda')

# RESNET 18(学習済みモデル)を利用
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, len(dataset.categories))
    
model = model.to(device)

model_save_button = ipywidgets.Button(description='save model')
model_load_button = ipywidgets.Button(description='load model')
model_path_widget = ipywidgets.Text(description='model path', value='/nvdli-nano/data/classification/my_model.pth')

def load_model(c):
    model.load_state_dict(torch.load(model_path_widget.value))
model_load_button.on_click(load_model)
    
def save_model(c):
    torch.save(model.state_dict(), model_path_widget.value)
model_save_button.on_click(save_model)

model_widget = ipywidgets.VBox([
    model_path_widget,
    ipywidgets.HBox([model_load_button, model_save_button])
])

# display(model_widget)
print("model configured and model_widget created")

ライブ処理実行

ここでカメラから入力されてきたデータをどう解析するかをリアルタイムに観察可能にする

import threading
import time
from utils import preprocess
import torch.nn.functional as F

state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')
prediction_widget = ipywidgets.Text(description='prediction')
score_widgets = []
for category in dataset.categories:
    score_widget = ipywidgets.FloatSlider(min=0.0, max=1.0, description=category, orientation='vertical')
    score_widgets.append(score_widget)

def live(state_widget, model, camera, prediction_widget, score_widget):
    global dataset
    while state_widget.value == 'live':
        image = camera.value
        preprocessed = preprocess(image)
        output = model(preprocessed)
        output = F.softmax(output, dim=1).detach().cpu().numpy().flatten()
        category_index = output.argmax()
        prediction_widget.value = dataset.categories[category_index]
        for i, score in enumerate(list(output)):
            score_widgets[i].value = score
            
def start_live(change):
    if change['new'] == 'live':
        execute_thread = threading.Thread(target=live, args=(state_widget, model, camera, prediction_widget, score_widget))
        execute_thread.start()

state_widget.observe(start_live, names='value')

live_execution_widget = ipywidgets.VBox([
    ipywidgets.HBox(score_widgets),
    prediction_widget,
    state_widget
])

# ive_execution_widgetを表示
print("live_execution_widget created")

学習と評価

BATCH_SIZE = 8

optimizer = torch.optim.Adam(model.parameters())

epochs_widget = ipywidgets.IntText(description='epochs', value=1)
eval_button = ipywidgets.Button(description='evaluate')
train_button = ipywidgets.Button(description='train')
loss_widget = ipywidgets.FloatText(description='loss')
accuracy_widget = ipywidgets.FloatText(description='accuracy')
progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')

def train_eval(is_training):
    global BATCH_SIZE, LEARNING_RATE, MOMENTUM, model, dataset, optimizer, eval_button, train_button, accuracy_widget, loss_widget, progress_widget, state_widget
    
    try:
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True
        )

        state_widget.value = 'stop'
        train_button.disabled = True
        eval_button.disabled = True
        time.sleep(1)

        if is_training:
            model = model.train()
        else:
            model = model.eval()
        while epochs_widget.value > 0:
            i = 0
            sum_loss = 0.0
            error_count = 0.0
            for images, labels in iter(train_loader):
                # デバイスにデータを送る
                images = images.to(device)
                labels = labels.to(device)

                if is_training:
                    # ゼロ勾配(パラメータ)
                    optimizer.zero_grad()

                # モデルを実行して出力を取得する
                outputs = model(images)

                # ロスを計算
                loss = F.cross_entropy(outputs, labels)

                if is_training:
                    # backpropogation実行して勾配を累積
                    loss.backward()

                    # step optimizer (パラメータを調整)
                    optimizer.step()

                # progressを増やす
                error_count += len(torch.nonzero(outputs.argmax(1) - labels).flatten())
                count = len(labels.flatten())
                i += count
                sum_loss += float(loss)
                progress_widget.value = i / len(dataset)
                loss_widget.value = sum_loss / i
                accuracy_widget.value = 1.0 - error_count / i
                
            if is_training:
                epochs_widget.value = epochs_widget.value - 1
            else:
                break
    except e:
        pass
    model = model.eval()

    train_button.disabled = False
    eval_button.disabled = False
    state_widget.value = 'live'
    
train_button.on_click(lambda c: train_eval(is_training=True))
eval_button.on_click(lambda c: train_eval(is_training=False))
    
train_eval_widget = ipywidgets.VBox([
    epochs_widget,
    progress_widget,
    loss_widget,
    accuracy_widget,
    ipywidgets.HBox([train_button, eval_button])
])

# train_eval_widgetを表示
print("trainer configured and train_eval_widget created")

widgetの表示

下図のようなwidgetが表示される

all_widget = ipywidgets.VBox([
    ipywidgets.HBox([data_collection_widget, live_execution_widget]), 
    train_eval_widget,
    model_widget
])

display(all_widget)
wedgets(https://developer.nvidia.com/)

遊んでみた

いざ実験!!!

サムズアップ・ダウンそれぞれを”add”ボタンをおし写真をとる.公式では33個ずつ撮影していました.

エポック数を調節し”Train”ボタンを押して学習

完了すると,以下で示すように右のバーチャートで判別率とその下で判別内容を見ることができます.

Thumbs-up (引用:https://youtu.be/rSqIvLQ8Meg)

学習が足らないと,以下の画像のように遠くのものなどは判定精度が落ちるますね.

Thumbs classification error  (引用:https://youtu.be/rSqIvLQ8Meg)

自分でも複数回試した結果,距離と角度,背景も需要(当たり前)みたいで60枚くらいとると微妙なとこも判別できるようになりました〜

ってことで今回はこのくらいで...

卒研落ち着いたらまた遊んでいきたいと思います(ってことは来年)


引用元:

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です