Fine tuning là gì

1. Introduction

1.1 Fine-tuning là gì ?

Chắc hẳn hầu hết ai thao tác làm việc với những mã sản phẩm vào deep learning những sẽ nghe/thân quen với quan niệm Transfer learningFine tuning. Khái niệm tổng quát: Transfer learning là tận dụng tri thức học tập được từ 1 vấn đề nhằm áp dụng vào 1 vấn đề tất cả tương quan khác. Một ví dụ đối chọi giản: nỗ lực vị train 1 Model bắt đầu trọn vẹn đến bài bác toán phân loại chó/mèo, fan ta có thể tận dụng tối đa 1 model đã có được train ở ImageNet dataset cùng với hằng triệu hình ảnh. Pre-trained model này sẽ tiến hành train tiếp trên tập dataset chó/mèo, quá trình train này ra mắt nkhô hanh hơn, kết quả hay giỏi rộng. Có không ít hình dáng Transfer learning, những bạn có thể xem thêm trong bài xích này: Tổng hòa hợp Transfer learning. Trong bài bác này, bản thân đang viết về 1 dạng transfer learning phổ biến: Fine-tuning.

Bạn đang xem: Fine tuning là gì

Hiểu dễ dàng, fine-tuning là các bạn rước 1 pre-trained Mã Sản Phẩm, tận dụng một trong những phần hoặc cục bộ những layer, thêm/sửa/xoá 1 vài ba layer/nhánh để tạo thành 1 Model bắt đầu. Thường những layer đầu của Mã Sản Phẩm được freeze (đóng góp băng) lại - tức weight những layer này sẽ không biến thành biến hóa cực hiếm trong quá trình train. Lý do vì chưng các layer này đã có khả năng trích xuất đọc tin mức trìu tượng phải chăng , khả năng này được học trường đoản cú quá trình training trước đó. Ta freeze lại để tận dụng tối đa được kỹ năng này với góp việc train ra mắt nkhô cứng hơn (Model chỉ phải update weight nghỉ ngơi những layer cao). Có tương đối nhiều những Object detect mã sản phẩm được chế tạo dựa trên những Classifier Model. VD Retina Mã Sản Phẩm (Object detect) được xây dừng với backbone là Resnet.


*

1.2 Tại sao pytorch nạm vày Keras ?

Chủ đề bài viết hôm nay, mình đang trả lời fine-tuning Resnet50 - 1 pre-trained model được cung ứng sẵn vào torchvision của pytorch. Tại sao là pytorch nhưng mà không hẳn Keras ? Lý vì vì việc fine-tuning mã sản phẩm trong keras cực kỳ dễ dàng. Dưới đấy là 1 đoạn code minc hoạ mang đến Việc xuất bản 1 Unet dựa vào Resnet trong Keras:

from tensorflow.keras import applicationsresnet = applications.resnet50.ResNet50()layer_3 = resnet.get_layer("activation_9").outputlayer_7 = resnet.get_layer("activation_21").outputlayer_13 = resnet.get_layer("activation_39").outputlayer_16 = resnet.get_layer("activation_48").output#Adding outputs decoder with encoder layersfcn1 = Conv2D(...)(layer_16)fcn2 = Conv2DTranspose(...)(fcn1)fcn2_skip_connected = Add()()fcn3 = Conv2DTranspose(...)(fcn2_skip_connected)fcn3_skip_connected = Add()()fcn4 = Conv2DTranspose(...)(fcn3_skip_connected)fcn4_skip_connected = Add()()fcn5 = Conv2DTranspose(...)(fcn4_skip_connected)Unet = Model(inputs = resnet.input, outputs=fcn5)Quý khách hàng hoàn toàn có thể thấy, fine-tuning model vào Keras thực thụ vô cùng dễ dàng, dễ có tác dụng, dễ nắm bắt. Việc add thêm các nhánh rất giản đơn vì chưng cú pháp đơn giản dễ dàng. Trong pytorch thì trở lại, xây dừng 1 model Unet tương tự sẽ rất vất vả cùng phức hợp. Người bắt đầu học sẽ gặp trở ngại vị bên trên mạng hiếm hoi các chỉ dẫn mang đến Việc này. Vậy đề xuất bài xích này bản thân đang khuyên bảo cụ thể giải pháp fine-tune vào pytorch để áp dụng vào bài xích toán Visual Saliency prediction

2. Visual Saliency prediction

2.1 What is Visual Saliency ?


*

Khi quan sát vào 1 bức ảnh, đôi mắt thường có xu hướng triệu tập chú ý vào 1 vài ba cửa hàng chủ yếu. Hình ảnh trên đó là 1 minc hoạ, màu rubi được sử dụng để bộc lộ mức độ thu hút. Saliency prediction là bài xích tân oán tế bào rộp sự tập trung của đôi mắt tín đồ lúc quan tiền gần cạnh 1 bức ảnh. Cụ thể, bài toán yên cầu xây dựng 1 Model, Mã Sản Phẩm này thừa nhận hình họa nguồn vào, trả về 1 mask tế bào phỏng cường độ lôi cuốn. vì thế, Mã Sản Phẩm nhận vào 1 đầu vào image với trả về 1 mask bao gồm size tương tự.


*

Để rõ hơn về bài xích toán này, bạn cũng có thể đọc bài: Visual Saliency Prediction with Contextual Encoder-Decoder Network.Dataphối thông dụng nhất: SALICON DATASET

2.2 Unet

Note: Quý Khách có thể làm lơ phần này giả dụ vẫn biết về Unet

Đây là 1 bài xích toán thù Image-to-Image. Để giải quyết và xử lý bài toán này, bản thân sẽ xây dựng 1 Model theo bản vẽ xây dựng Unet. Unet là một trong những phong cách thiết kế được áp dụng những vào bài toán thù Image-to-image như: semantic segmentation, tự động hóa color, super resolution ... Kiến trúc của Unet tất cả điểm tương tự như với bản vẽ xây dựng Encoder-Decoder đối xứng, có thêm các skip connection tự Encode quý phái Decode khớp ứng. Về cơ bạn dạng, những layer càng tốt càng trích xuất đọc tin ở tại mức trìu tượng cao, điều này đồng nghĩa với vấn đề những biết tin nút trìu tượng thấp nlỗi mặt đường đường nét, màu sắc, độ phân giải... sẽ bị mất non đi trong quá trình Viral. Người ta thêm những skip-connection vào để giải quyết vụ việc này.

Với phần Encode, feature-map được downscale bởi những Convolution. Ngược lại, ở vị trí decode, feature-map được upscale vì chưng những Upsampling layer, vào bài này bản thân áp dụng các Convolution Transpose.


2.3 Resnet

Để giải quyết và xử lý bài toán thù, bản thân sẽ xây dựng dựng Model Unet với backbone là Resnet50. quý khách hàng đề xuất mày mò về Resnet giả dụ không biết về phong cách xây dựng này. Hãy quan tiền gần kề hình minh hoạ sau đây. Resnet50 được tạo thành những khối bự . Unet được gây ra với Encoder là Resnet50. Ta vẫn kéo ra output của từng khối, tạo nên các skip-connection kết nối từ Encoder quý phái Decoder. Decoder được tạo do những Convolution Transpose layer (xen kẹt trong số đó là những lớp Convolution nhằm mục tiêu bớt số chanel của feature bản đồ -> bớt con số weight đến model).


Theo cách nhìn cá nhân, pytorch rất đơn giản code, dễ nắm bắt hơn không hề ít đối với Tensorflow 1.x hoặc ngang ngửa Keras. Tuy nhiên, câu hỏi fine-tuning model trong pytorch lại cực nhọc rộng tương đối nhiều đối với Keras. Trong Keras, ta không nên thừa quan tâm tới kiến trúc, luồng cách xử lý của model, chỉ việc lấy ra những output tại một số ít layer một mực làm cho skip-connection, ghép nối cùng tạo thành model new.


Trong pytorch thì trở lại, bạn cần hiểu được luồng cách xử lý và copy code các layer mong giữ lại trong model bắt đầu. Hình trên là code của resnet50 vào torchvision. quý khách hoàn toàn có thể xem thêm link: torchvision-resnet50. bởi vậy Lúc desgin Unet nhỏng phong cách thiết kế sẽ diễn tả bên trên, ta buộc phải bảo đảm đoạn code từ bỏ Conv1 -> Layer4 không biến thành chuyển đổi. Hãy đọc phần tiếp theo để nắm rõ rộng.

Xem thêm: " Right Of Way Là Gì ? Định Nghĩa Và Giải Thích Ý Nghĩa Right Of Way

3. Code

Tất cả code của chính mình được gói gọn vào tệp tin notebook Salicon_main.ipynb. Bạn có thể thiết lập về và run code theo link github: github/trungthanhnguyen0502 . Trong nội dung bài viết bản thân sẽ chỉ đưa ra phần đông đoạn code bao gồm.

Import các package

import albumentations as Aimport numpy as npimport torchimport torchvisionimport torch.nn as nn import torchvision.transforms as Timport torchvision.models as modelsfrom torch.utils.data import DataLoader, Datasetimport ....

3.1 utils functions

Trong pytorch, dữ liệu gồm thứ trường đoản cú dimension khác với Keras/TF/numpy. thường thì cùng với numpy giỏi keras, ảnh bao gồm dimension theo trang bị tự (batchform size,h,w,chanel)(batchkích cỡ, h, w, chanel)(batchkích thước,h,w,chanel). Thđọng trường đoản cú trong Pytorch ngược chở lại là (batchkích cỡ,chanel,h,w)(batchkích cỡ, chanel, h, w)(batchkích thước,chanel,h,w). Mình sẽ xây dựng dựng 2 hàm toTensor và toNumpy để biến hóa tương hỗ giữa nhị format này.

def toTensor(np_array, axis=(2,0,1)): return torch.tensor(np_array).permute(axis)def toNumpy(tensor, axis=(1,2,0)): return tensor.detach().cpu().permute(axis).numpy() ## display one image in notebookdef plot_img(img): ... ## display multi imagedef plot_imgs(imgs): ...

3.2 Define model

3.2.1 Conv & Deconv

Mình sẽ xây dựng dựng 2 function trả về module Convolution và Convolution Transpose (Deconv)

def Deconv(n_input đầu vào, n_output, k_size=4, stride=2, padding=1): Tconv = nn.ConvTranspose2d( n_input đầu vào, n_output, kernel_size=k_form size, stride=stride, padding=padding, bias=False) blochồng = < Tconv, nn.BatchNorm2d(n_output), nn.LeakyReLU(inplace=True), > return nn.Sequential(*block) def Conv(n_input đầu vào, n_output, k_size=4, stride=2, padding=0, bn=False, dropout=0): conv = nn.Conv2d( n_đầu vào, n_output, kernel_size=k_kích cỡ, stride=stride, padding=padding, bias=False) blochồng = < conv, nn.BatchNorm2d(n_output), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout) > return nn.Sequential(*block)

3.2.2 Unet model

Init function: ta đang copy những layer phải giữ lại trường đoản cú resnet50 vào unet. Sau kia khởi chế tạo các Conv / Deconv layer và các layer cần thiết.

Forward function: cần đảm bảo an toàn luồng cách xử trí của resnet50 được giữ nguyên tương đương code cội (trừ Fully-connected layer). Sau đó ta ghnghiền nối những layer lại theo bản vẽ xây dựng Unet vẫn mô tả vào phần 2.

Tạo model: bắt buộc load resnet50 và truyền vào Unet. Đừng quên Freeze các layer của resnet50 vào Unet.

class Unet(nn.Module): def __init__(self, resnet): super().__init__() self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.tanh = nn.Tanh() self.sigmoid = nn.Sigmoid() # get some layer from resnet lớn make skip connection self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 self.layer4 = resnet.layer4 # convolution layer, use to lớn reduce the number of channel => reduce weight number self.conv_5 = Conv(2048, 512, 1, 1, 0) self.conv_4 = Conv(1536, 512, 1, 1, 0) self.conv_3 = Conv(768, 256, 1, 1, 0) self.conv_2 = Conv(384, 128, 1, 1, 0) self.conv_1 = Conv(128, 64, 1, 1, 0) self.conv_0 = Conv(32, 1, 3, 1, 1) # deconvolution layer self.deconv4 = Deconv(512, 512, 4, 2, 1) self.deconv3 = Deconv(512, 256, 4, 2, 1) self.deconv2 = Deconv(256, 128, 4, 2, 1) self.deconv1 = Deconv(128, 64, 4, 2, 1) self.deconv0 = Deconv(64, 32, 4, 2, 1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) skip_1 = x x = self.maxpool(x) x = self.layer1(x) skip_2 = x x = self.layer2(x) skip_3 = x x = self.layer3(x) skip_4 = x x5 = self.layer4(x) x5 = self.conv_5(x5) x4 = self.deconv4(x5) x4 = torch.cat(, dim=1) x4 = self.conv_4(x4) x3 = self.deconv3(x4) x3 = torch.cat(, dim=1) x3 = self.conv_3(x3) x2 = self.deconv2(x3) x2 = torch.cat(, dim=1) x2 = self.conv_2(x2) x1 = self.deconv1(x2) x1 = torch.cat(, dim=1) x1 = self.conv_1(x1) x0 = self.deconv0(x1) x0 = self.conv_0(x0) x0 = self.sigmoid(x0) return x0 device = torch.device("cuda")resnet50 = models.resnet50(pretrained=True)Model = Unet(resnet50)model.to(device)## Freeze resnet50"s layers in Unetfor i, child in enumerate(model.children()): if i 7: for param in child.parameters(): param.requires_grad = False

3.3 Dataset and Dataloader

Dataphối trả dấn 1 các mục các image_path và mask_dir, trả về image và mask tương ứng.

Define MaskDataset

class MaskDataset(Dataset): def __init__(self, img_fns, mask_dir, transforms=None): self.img_fns = img_fns self.transforms = transforms self.mask_dir = mask_dir def __getitem__(self, idx): img_path = self.img_fns img_name = img_path.split("/")<-1>.split(".")<0> mask_fn = f"self.mask_dir/img_name.png" img = cv2.imread(img_path) mask = cv2.imread(mask_fn) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if self.transforms: sample = "image": img, "mask": mask sample = self.transforms(**sample) img = sample<"image"> mask = sample<"mask"> # to lớn Tensor img = img/255.0 mask = np.expand_dims(mask, axis=-1)/255.0 mask = toTensor(mask).float() img = toTensor(img).float() return img, mask def __len__(self): return len(self.img_fns)Test dataset

img_fns = glob("./Salicon_dataset/image/train/*.jpg")mask_dir = "./Salicon_dataset/mask/train"train_transkhung = A.Compose(< A.Resize(width=256,height=256, p=1), A.RandomSizedCrop(<240,256>, height=256, width=256, p=0.4), A.HorizontalFlip(p=0.5), A.Rotate(limit=(-10,10), p=0.6),>)train_datamix = MaskDataset(img_fns, mask_dir, train_transform)train_loader = DataLoader(train_dataphối, batch_size=4, shuffle=True, drop_last=True)# Test datasetimg, mask = next(iter(train_dataset))img = toNumpy(img)mask = toNumpy(mask)<:,:,0>img = (img*255.0).astype(np.uint8)mask = (mask*255.0).astype(np.uint8)heatmap_img = cv2.applyColorMap(mask, cv2.COLORMAP_JET)combine_img = cv2.addWeighted(img, 0.7, heatmap_img, 0.3, 0)plot_imgs()Test result