PyTorch frozen 怎么使用
本篇内容介绍了“PyTorch frozen怎么使用”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!
创新互联建站专注于企业营销型网站建设、网站重做改版、眉县网站定制设计、自适应品牌网站建设、H5响应式网站、购物商城网站建设、集团公司官网建设、成都外贸网站制作、高端网站制作、响应式网页设计等建站业务,价格优惠性价比高,为眉县等各大城市提供网站开发制作服务。
1. 加载预训练权重(pretrain),所有层使用相同的学习率(lr)一起训练
# ============================ step 2/5: model ============================

# 1/3 Build the ResNet-18 model.
resnet18_ft = models.resnet18()

# 2/3 Load the pretrained weights from the local checkpoint file.
# Set flag to 0 to skip loading and keep the freshly built weights.
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    # map_location="cpu" lets the checkpoint load on a machine without a GPU
    # regardless of the device it was saved from; the whole model is moved
    # to `device` at the end of this step.
    state_dict_load = torch.load(path_pretrained_model, map_location="cpu")
    resnet18_ft.load_state_dict(state_dict_load)

# 3/3 Replace the fc layer so the output size matches the number of classes.
# This must happen after load_state_dict, because the checkpoint's fc shape
# belongs to the original model, not to the new `classes`-way head.
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)
2. frozen:冻结预训练的卷积层,只训练新替换的 fc 层
# ============================ step 2/5: model ============================

# 1/3 Build the ResNet-18 model.
resnet18_ft = models.resnet18()

# 2/3 Load the pretrained weights from the local checkpoint file.
# Set flag to 0 to skip loading and keep the freshly built weights.
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    # map_location="cpu" lets the checkpoint load on a machine without a GPU
    # regardless of the device it was saved from; the whole model is moved
    # to `device` at the end of this step.
    state_dict_load = torch.load(path_pretrained_model, map_location="cpu")
    resnet18_ft.load_state_dict(state_dict_load)

# Method 1: freeze the convolutional layers.
# Freezing is OFF while flag_m1 is 0; set it to 1 to enable the loop below.
flag_m1 = 0
if flag_m1:
    # This runs before the fc layer is replaced, so every parameter that
    # exists at this point stops receiving gradients. The nn.Linear created
    # in step 3/3 is a new module and is not touched by this loop.
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    # NOTE(review): the collapsed original does not show whether this print
    # sits inside the `if`; it is placed here as part of the freeze demo.
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))

# 3/3 Replace the fc layer so the output size matches the number of classes.
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)
3. 不同学习率:为卷积层和 fc 层分别设置不同的学习率(完整代码)
# -*- coding: utf-8 -*-
"""
# @brief      : model fine-tuning methods
"""
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from tools.common_tools2 import set_seed
from tools.my_dataset import AntsDataset

BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use device :{}".format(device))

set_seed(1)  # fix the random seed
label_name = {"ants": 0, "bees": 1}

# Hyper-parameters
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10   # print training stats every `log_interval` iterations
val_interval = 1    # run validation every `val_interval` epochs
classes = 2
start_epoch = -1
lr_decay_step = 7   # StepLR multiplies the lr by gamma every 7 epochs

# ============================ step 1/5: data ============================
data_dir = os.path.join(BASEDIR, "..", "..", "data/hymenoptera_data")
train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build the dataset instances
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)

# Build the data loaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5: model ============================

# 1/3 Build the ResNet-18 model.
resnet18_ft = models.resnet18()

# 2/3 Load the pretrained weights from the local checkpoint file.
# Set flag to 0 to skip loading and keep the freshly built weights.
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    # map_location="cpu" lets the checkpoint load on a machine without a GPU
    # regardless of the device it was saved from; the whole model is moved
    # to `device` at the end of this step.
    state_dict_load = torch.load(path_pretrained_model, map_location="cpu")
    resnet18_ft.load_state_dict(state_dict_load)

# Method 1: freeze the convolutional layers.
# Freezing is OFF while flag_m1 is 0; set it to 1 to enable the loop below.
flag_m1 = 0
if flag_m1:
    # Runs before the fc layer is replaced, so the new nn.Linear created in
    # step 3/3 is not touched by this loop.
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    # NOTE(review): the collapsed original does not show whether this print
    # sits inside the `if`; it is placed here as part of the freeze demo.
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))

# 3/3 Replace the fc layer so the output size matches the number of classes.
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)

# ============================ step 3/5: loss function ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5: optimizer ============================
# Method 2: a separate (smaller) learning rate for the convolutional layers.
# Set flag to 0 to train every parameter with the same learning rate instead.
flag = 1
if flag:
    # id() identifies the fc parameter objects, so they can be excluded
    # from the "base" group; a set keeps the membership test O(1).
    fc_params_id = {id(p) for p in resnet18_ft.fc.parameters()}
    base_params = [p for p in resnet18_ft.parameters() if id(p) not in fc_params_id]
    optimizer = optim.SGD([
        # LR * 0 gives the base layers a learning rate of zero, so the
        # optimizer leaves their weights unchanged. Use a small non-zero
        # factor here to fine-tune them slowly instead.
        {'params': base_params, 'lr': LR * 0},
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)
else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)

# Learning-rate decay schedule
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)

# ============================ step 5/5: training ============================
train_curve = list()   # one loss value per training iteration
valid_curve = list()   # one mean loss value per validated epoch

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    resnet18_ft.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet18_ft(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # Classification statistics. detach() replaces the deprecated
        # `.data`, and .item() yields a plain Python number without the
        # cpu()/numpy() round-trip.
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Training log
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

            # Printed regardless of flag_m1 so the effect of freezing or a
            # zero learning rate on conv1 can be observed.
            # NOTE(review): nesting inferred; the collapsed original does
            # not show this line's indentation — confirm.
            print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))

    scheduler.step()  # update the learning rate once per epoch

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        resnet18_ft.eval()
        with torch.no_grad():
            for data in valid_loader:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                loss_val += loss.item()

            loss_val_mean = loss_val / len(valid_loader)
            valid_curve.append(loss_val_mean)
            # len(valid_loader) equals the final iteration count and does
            # not rely on a loop variable leaking out of the for loop.
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, len(valid_loader), len(valid_loader), loss_val_mean, correct_val / total_val))
        resnet18_ft.train()

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
# valid_curve holds one value per validated epoch; scale its x positions to
# iteration counts so both curves share the same axis.
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
“PyTorch frozen怎么使用”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注创新互联网站,小编将为大家输出更多高质量的实用文章!
分享名称:PyTorch frozen 怎么使用
文章网址:http://pcwzsj.com/article/ijescg.html