五月天青色头像情侣网名,国产亚洲av片在线观看18女人,黑人巨茎大战俄罗斯美女,扒下她的小内裤打屁股

歡迎光臨散文網 會員登陸 & 注冊

MMNet 微表情識別(CASME2 數據集)

2023-06-24 18:22 作者:感覺__站不如油管  | 我要投稿

原github地址:https://github.com/muse1998/MMNet

代碼和數據集都存在一些問題,經過修改後方能夠運行

main.py

CA_block.py

PC_module.py


main.py

# -*- coding: utf-8 -*-
import torch
import math
import numpy as np
import torchvision.models
import torch.utils.data as data
from torchvision import transforms
import CV2
import pandas as pd
import os, torch
import torch.nn as nn
#import image_utils
import argparse, random
from functools import partial

from MMNET.CA_block import resnet18_pos_attention

from PC_module import VisionTransformer_POS

from torchvision.transforms import Resize
torch.set_printoptions(precision=3, edgeitems=14, linewidth=350)



def parse_args(argv=None):
    """Parse the command-line options of the LOSO training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            makes the parser usable from tests or notebooks.

    Returns:
        argparse.Namespace holding the parsed options.

    NOTE(review): run_training only reads raf_path, pretrained, optimizer,
    lr (sgd path), momentum and workers; the remaining flags appear to be
    leftovers and are kept for CLI compatibility.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--raf_path', type=str, default='D:/CASME2/', help='CASME2 dataset root directory.')
    parser.add_argument('--checkpoint', type=str, default='D:/CASME2/',
                        help='Pytorch checkpoint file path')
    parser.add_argument('--pretrained', type=str, default=None,
                        help='Pretrained weights')
    parser.add_argument('--beta', type=float, default=0.7, help='Ratio of high importance group in one mini-batch.')
    parser.add_argument('--relabel_epoch', type=int, default=1000,
                        help='Relabeling samples on each mini-batch after this many epochs (default: 1000).')
    parser.add_argument('--batch_size', type=int, default=34, help='Batch size.')
    parser.add_argument('--optimizer', type=str, default="adam", help='Optimizer, adam or sgd.')
    parser.add_argument('--lr', type=float, default=0.0001, help='Initial learning rate for sgd.')
    parser.add_argument('--momentum', default=0.9, type=float, help='Momentum for sgd')
    parser.add_argument('--workers', default=0, type=int, help='Number of data loading workers (default: 0)')
    parser.add_argument('--epochs', type=int, default=1000, help='Total training epochs.')
    parser.add_argument('--drop_rate', type=float, default=0, help='Drop out rate.')
    return parser.parse_args(argv)






class RafDataSet(data.Dataset):
    """CASME2 micro-expression clips split for leave-one-subject-out (LOSO) runs.

    Each sample is nine aligned face frames sampled between a clip's onset,
    apex and offset frame numbers (read from the CASME2 coding spreadsheet),
    followed by a 3-class emotion index and a 38-dim multi-hot action-unit
    (AU) vector.

    Args:
        raf_path: Dataset root. It must contain ``CASME2-coding-20140508.xlsx``
            and images at
            ``Cropped-updated/Cropped/sub<subject>/<file name>/reg_img<frame>.jpg``.
        phase: ``'train'`` keeps every subject except ``num_loso``; any other
            value keeps only ``num_loso``. Also selects the frame sampling
            strategy (random jitter for training, fixed for testing).
        num_loso: Held-out subject id, compared as a string with the
            spreadsheet's ``Subject`` column.
        transform: Optional per-frame transform applied to each RGB array.
            When ``None``, the raw RGB arrays are returned.
        basic_aug: Stored for compatibility; not used by this class.
        transform_norm: Optional transform applied (training phase only) to
            the channel-wise concatenation of all nine transformed frames, so
            every frame receives the same random augmentation.
    """

    # Emotion name -> class index. Rows with any other emotion (e.g. 'others')
    # are dropped.
    LABEL_MAP = {
        'happiness': 0,
        'surprise': 1,
        'repression': 2,
        'disgust': 2,
        'fear': 2,
        'sadness': 2,
    }
    # Length of the multi-hot AU target; AU k sets index k - 1.
    NUM_AU = 38
    CROP_DIR = 'Cropped-updated/Cropped/'

    def __init__(self, raf_path, phase, num_loso, transform=None, basic_aug=False, transform_norm=None):
        self.phase = phase
        self.transform = transform
        self.raf_path = raf_path
        self.transform_norm = transform_norm
        self.basic_aug = basic_aug

        # `usecols` keeps seven columns, consumed positionally below as:
        # subject, file name, onset, apex, offset, action units, emotion.
        df = pd.read_excel(os.path.join(self.raf_path, 'CASME2-coding-20140508.xlsx'),
                           usecols=[0, 1, 3, 4, 5, 7, 8])
        df['Subject'] = df['Subject'].apply(str)
        if phase == 'train':
            dataset = df.loc[df['Subject'] != num_loso]
        else:
            dataset = df.loc[df['Subject'] == num_loso]

        self.file_paths_on = []    # onset frame numbers
        self.file_paths_off = []   # offset frame numbers
        self.file_paths_apex = []  # apex frame numbers
        self.label_all = []        # class index from LABEL_MAP
        self.label_au = []         # list of AU ids per clip
        self.sub = []
        self.file_names = []

        for sub, f, onset, apex, offset, label_au, label_all in dataset.itertuples(index=False, name=None):
            if label_all not in self.LABEL_MAP:
                continue
            self.file_paths_on.append(onset)
            self.file_paths_off.append(offset)
            self.file_paths_apex.append(apex)
            self.sub.append(sub)
            self.file_names.append(f)
            self.label_all.append(self.LABEL_MAP[label_all])
            if isinstance(label_au, str):
                # Several AUs joined with '+', e.g. '1+2+4'.
                self.label_au.append(label_au.split('+'))
            else:
                # A single AU stored as a number (int, numpy int or float cell).
                self.label_au.append([int(label_au)])

    def __len__(self):
        return len(self.file_paths_on)

    def _frame_indices(self, onset, apex, offset):
        """Return the nine frame numbers ``[on0..on3, apex, off0..off3]``.

        Test phase: onset->apex and apex->offset are each cut into four equal
        steps. Train phase: every sample point is drawn with
        ``random.randint`` from a small window around the corresponding step.
        """
        onset, apex, offset = int(onset), int(apex), int(offset)
        rise = apex - onset
        fall = offset - apex

        if self.phase != 'train':
            return ([onset + int(k * rise / 4) for k in range(4)]
                    + [apex]
                    + [apex + int(k * fall / 4) for k in range(1, 4)]
                    + [offset])

        def pick(base, span, lo, hi):
            # Uniform frame in [base + lo*span/4, base + hi*span/4] (truncated).
            return random.randint(base + int(lo * span / 4), base + int(hi * span / 4))

        # List items are evaluated left to right, so the RNG is consumed in
        # frame order.
        return [
            random.randint(onset, onset + int(0.2 * rise / 4)),
            pick(onset, rise, 0.9, 1.1),
            pick(onset, rise, 1.8, 2.2),
            pick(onset, rise, 2.7, 3.3),
            random.randint(apex - int(0.15 * rise / 4), apex + int(0.15 * fall / 4)),
            pick(apex, fall, 0.9, 1.1),
            pick(apex, fall, 1.8, 2.2),
            pick(apex, fall, 2.9, 3.1),
            random.randint(apex + int(3.8 * fall / 4), offset),
        ]

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` with OpenCV and return a contiguous HxWx3 RGB array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (``imread``
                returns ``None`` instead of raising).
        """
        # NOTE(review): OpenCV is imported as `CV2` at the top of this file;
        # the package's module name is normally `cv2` — confirm the import.
        bgr = CV2.imread(path)
        if bgr is None:
            raise FileNotFoundError('Could not read image: ' + path)
        return np.ascontiguousarray(bgr[:, :, ::-1])  # BGR -> RGB

    def __getitem__(self, idx):
        frame_ids = self._frame_indices(self.file_paths_on[idx],
                                        self.file_paths_apex[idx],
                                        self.file_paths_off[idx])
        sub = str(self.sub[idx])
        f = str(self.file_names[idx])
        images = [
            self._load_rgb(os.path.join(self.raf_path, self.CROP_DIR, 'sub' + sub, f,
                                        'reg_img' + str(n) + '.jpg').replace('\\', '/'))
            for n in frame_ids
        ]

        label_all = self.label_all[idx]
        au_target = torch.zeros(self.NUM_AU)
        for au in self.label_au[idx]:
            au_target[int(au) - 1] = 1

        if self.transform is None:
            # Previously this path fell through and returned None.
            return (*images, label_all, au_target)

        # Stack all frames channel-wise so transform_norm applies one shared
        # random augmentation to the whole clip.
        stacked = torch.cat([self.transform(img) for img in images], dim=0)
        if self.transform_norm is not None and self.phase == 'train':
            stacked = self.transform_norm(stacked)
        frames = stacked.split(3, dim=0)  # back to nine 3-channel frames
        return (*frames, label_all, au_target)


def initialize_weight_goog(m, n=''):
    """Initialize a single module in place (Google/EfficientNet-style scheme).

    * ``nn.Conv2d``: weight ~ N(0, sqrt(2 / fan_out)) with
      fan_out = kH * kW * out_channels; bias zeroed.
    * ``nn.BatchNorm2d``: weight = 1, bias = 0.
    * ``nn.Linear``: weight ~ U(-r, r) with r = 1 / sqrt(fan_in + fan_out),
      where fan_in is only counted when ``'routing_fn'`` is in ``n``;
      bias zeroed.

    Other module types are left untouched. Missing (``None``) weights or
    biases are skipped, so bias-free Linear layers and non-affine BatchNorm
    layers no longer raise ``AttributeError``.

    Args:
        m: Module to initialize (e.g. one item of ``model.modules()``).
        n: Qualified module name; only used by the Linear fan-in rule.
    """
    if isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            m.weight.data.fill_(1.0)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        fan_out = m.weight.size(0)
        fan_in = m.weight.size(1) if 'routing_fn' in n else 0
        init_range = 1.0 / math.sqrt(fan_in + fan_out)
        m.weight.data.uniform_(-init_range, init_range)
        if m.bias is not None:
            m.bias.data.zero_()


def criterion2(y_pred, y_true):
    """Multi-label categorical cross-entropy ("softmax + CE" for multi-label).

    For one sample with scores s and 0/1 targets y:
        loss = log(1 + sum_{i: y_i = 0} exp(s_i))
             + log(1 + sum_{j: y_j = 1} exp(-s_j))
    The result is averaged over all leading (batch) positions.

    Args:
        y_pred: Raw scores; classes on the last dimension.
        y_true: 0/1 targets with the same shape as ``y_pred``.

    Returns:
        Scalar tensor.
    """
    signed = y_pred * (1 - 2 * y_true)        # flip the sign of positive-class scores
    big = 1e12                                 # acts as -inf once subtracted
    only_neg = signed - big * y_true           # hide positives
    only_pos = signed - big * (1 - y_true)     # hide negatives
    anchor = torch.zeros_like(y_pred[..., :1])  # the exp(0) = 1 term
    neg_term = torch.logsumexp(torch.cat([only_neg, anchor], dim=-1), dim=-1)
    pos_term = torch.logsumexp(torch.cat([only_pos, anchor], dim=-1), dim=-1)
    return (neg_term + pos_term).mean()


class MMNet(nn.Module):
    """MMNet: a CA-block motion branch fused with a position-calibration ViT.

    ``forward`` only uses the onset (``x1``) and apex (``x5``) frames; the
    other frame arguments and ``if_shuffle`` are accepted for interface
    compatibility and ignored. It returns the main branch's logits.
    """

    def __init__(self):
        super().__init__()
        # Motion stem applied to (apex - onset): 3 -> 180 channels, stride 2.
        self.conv_act = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=180, kernel_size=3, stride=2,
                      padding=1, bias=False, groups=1),
            nn.BatchNorm2d(180),
            nn.ReLU(inplace=True),
        )
        # NOTE(review): `pos`, `head1`, `timeembed` and `avgpool` are not used
        # in forward; they are kept so parameter names stay unchanged.
        self.pos = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=512, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        # Position Calibration Module (sub-branch) over a 14x14 resized onset frame.
        self.vit_pos = VisionTransformer_POS(
            img_size=14, patch_size=1, embed_dim=512, depth=3, num_heads=4,
            mlp_ratio=2, qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path_rate=0.3,
        )
        self.resize = Resize([14, 14])
        # Main branch made of CA blocks.
        self.main_branch = resnet18_pos_attention()
        self.head1 = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(112 * 112, 38, bias=False),
        )
        self.timeembed = nn.Parameter(torch.zeros(1, 4, 111, 111))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x1, x2, x3, x4, x5, x6, x7, x8, x9, if_shuffle):
        onset, apex = x1, x5
        batch = onset.shape[0]
        # Token output presumably (B, 196, 512); reshaped to a 512x14x14 map.
        tokens = self.vit_pos(self.resize(onset))
        pos_embed = tokens.transpose(1, 2).view(batch, 512, 14, 14)
        motion = self.conv_act(apex - onset)
        logits, _ = self.main_branch(motion, pos_embed)
        return logits





def _build_transforms():
    """Return ``(train_transform, train_augmentation, val_transform)``.

    The per-frame transforms convert an RGB array to a normalized 224x224
    tensor. The augmentation is meant for the dataset's stacked multi-frame
    tensor so one random flip/rotation/crop is shared by all frames.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    train_aug = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(4),
        transforms.RandomCrop(224, padding=4),
    ])
    val_tf = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    return train_tf, train_aug, val_tf


def _load_pretrained(net, path):
    """Copy matching tensors from ``torch.load(path)['state_dict']`` into ``net``.

    The classifier keys ``module.fc.weight``/``module.fc.bias`` are skipped.
    Only keys that ``net`` actually has are copied and counted as loaded.
    """
    print("Loading pretrained weights...", path)
    # torch.load unpickles the file: only use trusted checkpoints.
    pretrained_state_dict = torch.load(path)['state_dict']
    model_state_dict = net.state_dict()
    loaded_keys = 0
    total_keys = 0
    for key, value in pretrained_state_dict.items():
        if key in ('module.fc.weight', 'module.fc.bias'):
            continue
        total_keys += 1
        if key in model_state_dict:
            model_state_dict[key] = value
            loaded_keys += 1
    print("Loaded params num:", loaded_keys)
    print("Total params num:", total_keys)
    net.load_state_dict(model_state_dict, strict=False)


def _macro_f1(tp, n_pred, n_true):
    """Mean per-class F1 (2*TP / (#predicted + #true)) over classes that occur.

    Classes that are neither predicted nor present in the labels are skipped.
    Returns 0.0 when no class qualifies (e.g. nothing was evaluated).
    """
    scores = [2 * tp[k] / (n_pred[k] + n_true[k])
              for k in range(len(tp))
              if n_pred[k] != 0 or n_true[k] != 0]
    return sum(scores) / len(scores) if scores else 0.0


def _train_one_epoch(net, loader, criterion, optimizer, device):
    """Train ``net`` for one pass over ``loader``.

    Returns:
        (summed batch loss as float, number of correct predictions, number of batches)
    """
    net.train()
    loss_sum = 0.0
    correct = 0
    n_batches = 0
    for batch in loader:
        *frames, label_all, _label_au = batch
        frames = [frame.to(device) for frame in frames]
        label_all = label_all.to(device)
        logits = net(*frames, False)
        loss = criterion(logits, label_all)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() detaches; summing the loss tensor itself would keep every
        # batch's autograd graph alive until the end of the epoch.
        loss_sum += loss.item()
        _, predicts = torch.max(logits, 1)
        correct += torch.eq(predicts, label_all).sum().item()
        n_batches += 1
    return loss_sum, correct, n_batches


def _evaluate(net, loader, criterion, device, num_classes):
    """Evaluate ``net`` on ``loader`` without gradients.

    Returns:
        (mean batch loss, #correct, #samples, TP, #predicted, #true), where the
        last three are float tensors of length ``num_classes``.
    """
    net.eval()
    tp = torch.zeros(num_classes)
    n_pred = torch.zeros(num_classes)
    n_true = torch.zeros(num_classes)
    loss_sum = 0.0
    n_batches = 0
    correct = 0
    n_samples = 0
    with torch.no_grad():
        for batch in loader:
            *frames, label_all, _label_au = batch
            frames = [frame.to(device) for frame in frames]
            label_all = label_all.to(device)
            logits = net(*frames, False)
            loss_sum += criterion(logits, label_all).item()
            n_batches += 1
            _, predicts = torch.max(logits, 1)
            correct += torch.eq(predicts, label_all).sum().item()
            n_samples += logits.size(0)
            predicts = predicts.cpu()
            labels = label_all.cpu()
            for cls in range(num_classes):
                n_pred[cls] += (predicts == cls).sum()
                n_true[cls] += (labels == cls).sum()
                tp[cls] += ((predicts == labels) & (predicts == cls)).sum()
    return loss_sum / max(n_batches, 1), correct, n_samples, tp, n_pred, n_true


def run_training():
    """Leave-one-subject-out (LOSO) training and evaluation of MMNet on CASME2.

    For every held-out subject a fresh MMNet is trained for 99 epochs on the
    other subjects and evaluated on the held-out subject after each epoch.
    Per subject, the best validation correct-count and the confusion counts
    of the best-F1 epoch are accumulated. Running totals and the pooled
    macro-F1 are printed after each subject.
    """
    args = parse_args()
    # Whether to keep the modules' default initialization; when False every
    # module is re-initialized with initialize_weight_goog.
    imagenet_pretrained = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_pin_memory = device.type == 'cuda'
    num_classes = 3

    data_transforms, data_transforms_norm, data_transforms_val = _build_transforms()
    criterion = torch.nn.CrossEntropyLoss()
    # Leave-one-subject-out protocol (subject ids as strings).
    LOSO = ['17', '26', '16', '9', '5', '24', '2', '13', '4', '23', '11', '12', '8', '14', '3', '19', '1', '10',
            '20', '21', '22', '15', '6', '25', '7']

    val_now = 0
    num_sum = 0
    n_pred_all = torch.zeros(num_classes)
    n_true_all = torch.zeros(num_classes)
    tp_all = torch.zeros(num_classes)

    for subj in LOSO:
        train_dataset = RafDataSet(args.raf_path, phase='train', num_loso=subj, transform=data_transforms,
                                   basic_aug=True, transform_norm=data_transforms_norm)
        val_dataset = RafDataSet(args.raf_path, phase='test', num_loso=subj, transform=data_transforms_val)
        # NOTE(review): batch size is fixed at 24; args.batch_size is not used.
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=24,
                                                   num_workers=args.workers,
                                                   shuffle=True,
                                                   pin_memory=use_pin_memory)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=24,
                                                 num_workers=args.workers,
                                                 shuffle=False,
                                                 pin_memory=use_pin_memory)
        print('num_sub', subj)
        print('Train set size:', len(train_dataset))
        print('Validation set size:', len(val_dataset))

        max_corr = 0
        max_f1 = 0
        best_tp = torch.zeros(num_classes)
        best_n_pred = torch.zeros(num_classes)
        best_n_true = torch.zeros(num_classes)

        # Model initialization (a fresh network per held-out subject).
        net_all = MMNet()
        if not imagenet_pretrained:
            for m in net_all.modules():
                initialize_weight_goog(m)
        if args.pretrained:
            _load_pretrained(net_all, args.pretrained)

        params_all = net_all.parameters()
        if args.optimizer == 'adam':
            # NOTE(review): lr/weight_decay are hard-coded; args.lr is not used here.
            optimizer_all = torch.optim.AdamW(params_all, lr=0.0008, weight_decay=0.7)
        elif args.optimizer == 'sgd':
            optimizer_all = torch.optim.SGD(params_all, args.lr,
                                            momentum=args.momentum,
                                            weight_decay=1e-4)
        else:
            raise ValueError("Optimizer not supported.")
        scheduler_all = torch.optim.lr_scheduler.ExponentialLR(optimizer_all, gamma=0.987)
        net_all = net_all.to(device)

        for i in range(1, 100):
            loss_sum, correct_sum, iter_cnt = _train_one_epoch(net_all, train_loader, criterion,
                                                               optimizer_all, device)
            # Learning-rate decay only during the first 50 epochs.
            if i <= 50:
                scheduler_all.step()
            acc = correct_sum / float(max(len(train_dataset), 1))
            running_loss = loss_sum / max(iter_cnt, 1)
            print('[Epoch %d] Training accuracy: %.4f. Loss: %.3f' % (i, acc, running_loss))

            val_loss, bingo_cnt, sample_cnt, tp, n_pred, n_true = _evaluate(
                net_all, val_loader, criterion, device, num_classes)
            avg_f1 = _macro_f1(tp, n_pred, n_true)
            val_acc = np.around(bingo_cnt / float(max(sample_cnt, 1)), 4)
            if bingo_cnt > max_corr:
                max_corr = bingo_cnt
            if avg_f1 >= max_f1:
                max_f1 = avg_f1
                best_tp, best_n_pred, best_n_true = tp, n_pred, n_true
            print("[Epoch %d] Validation accuracy:%.4f. Loss:%.3f, F1-score:%.3f" % (i, val_acc, val_loss, avg_f1))

        num_sum = num_sum + max_corr
        n_pred_all = n_pred_all + best_n_pred
        n_true_all = n_true_all + best_n_true
        tp_all = tp_all + best_tp
        f1_all = _macro_f1(tp_all, n_pred_all, n_true_all)
        val_now = val_now + len(val_dataset)
        print("[..........%s] correctnum:%d . zongshu:%d ? " % (subj, max_corr, len(val_dataset)))
        print("[ALL_corr]: %d [ALL_val]: %d" % (num_sum, val_now))
        print("[F1_now]: %.4f [F1_ALL]: %.4f" % (max_f1, f1_all))


if __name__ == "__main__":
    # Run the LOSO training/evaluation only when executed as a script.
    run_training()

CA_block.py


# -*- coding: utf-8 -*-

#import torch
#import torch.nn as nn
import torch
import torch.nn as nn

# Public names of this module. Only names defined in this file are listed;
# listing undefined names makes `from CA_block import *` raise AttributeError.
__all__ = ['ResNet', 'CABlock', 'resnet18_pos_attention']


# torchvision ImageNet checkpoint URLs, looked up by `_resnet` when
# pretrained=True.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``.

    With stride 1 this keeps the spatial size for any dilation.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1, groups=1):
    """Bias-free pointwise (1x1) ``nn.Conv2d`` with optional stride and groups."""
    layer = nn.Conv2d(in_planes, out_planes, 1, stride, bias=False, groups=groups)
    return layer

##CA BLOCK
class CABlock(nn.Module):
    """Residual block followed by a single-channel spatial attention map.

    The residual path is conv3x3 (strided) -> BN -> ReLU -> conv1x1 -> BN,
    added to the (optionally downsampled) input and passed through ReLU. An
    attention map in (0, 1) is then computed from the per-pixel channel mean
    and max of that output; if a map from the previous stage is supplied it
    is multiplied in.

    ``forward`` takes and returns a 3-tuple so blocks chain inside
    ``nn.Sequential``:
        input:  ``(x, attn_last, if_attn)``
        output: ``(out, attn, True)`` with ``attn`` of shape (B, 1, H, W).
    When ``if_attn`` is truthy, ``out`` is scaled by ``attn``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(CABlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # base_width is accepted for ResNet-builder compatibility but unused.
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample downsample the input when stride != 1.
        self.conv1 = conv3x3(inplanes, planes, stride, groups=groups)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv1x1(planes, planes, groups=groups)
        self.bn2 = norm_layer(planes)
        # 2 input channels = [channel mean, channel max] -> 1-channel map.
        self.attn = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.downsample = downsample
        self.stride = stride
        self.planes = planes

    def forward(self, x):
        # attn_last: attention map from the previous stage (downsampled by the
        # caller) used as prior knowledge; may be None for the first stage.
        x, attn_last, if_attn = x
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = self.relu(out + identity)

        # Spatial attention from per-pixel channel statistics.
        avg_out = torch.mean(out, dim=1, keepdim=True)
        max_out, _ = torch.max(out, dim=1, keepdim=True)
        attn = self.attn(torch.cat((avg_out, max_out), dim=1))
        if attn_last is not None:
            attn = attn_last * attn

        # Broadcasting the (B, 1, H, W) map over channels gives the same
        # values and gradients as tiling it `planes` times, without
        # materializing the copies.
        if if_attn:
            out = out * attn

        return out, attn, True





class ResNet(nn.Module):
    """MMNet main branch: a ResNet-style stack of attention blocks.

    The stem is a stride-1 3x3 conv from 180 input channels to 128, then four
    stages (widths 128, 128, 256, 512; the last three with stride 2). Each
    stage's attention map is max-pooled and fed to the next stage. The
    layer4 output is summed with ``POS`` and flattened into a linear
    classifier sized for 512 * expansion * 196 features, i.e. a 14x14 map
    (a 112x112 input after the three stride-2 stages).

    NOTE(review): ``num_classes`` is ignored; the classifier always has 5
    outputs.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=4, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 128
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # Each element says whether to replace that stage's 2x2 stride
            # with a dilated convolution instead.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(90 * 2, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False, groups=1)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # Downsamples attention maps between stages.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # unused; kept as an attribute
        self.layer1 = self._make_layer(block, 128, layers[0], groups=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0], groups=1)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1], groups=1)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2], groups=1)

        self.fc = nn.Linear(512 * block.expansion * 196, 5)
        self.drop = nn.Dropout(p=0.1)  # unused in forward; kept as an attribute
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch, so that each
        # residual block starts out behaving like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        # Works for whatever `block` class was passed in (bn3 for bottleneck-
        # style blocks, otherwise bn2) instead of naming classes that do not
        # exist in this module.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, block):
                    last_bn = getattr(m, 'bn3', None)
                    if last_bn is None:
                        last_bn = getattr(m, 'bn2', None)
                    if last_bn is not None:
                        nn.init.constant_(last_bn.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, groups=1):
        """Build one stage of ``blocks`` blocks; only the first may downsample."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x, POS):
        """Run the main branch and fuse it with the position embeddings.

        Args:
            x: Input of the main branch (180-channel motion features).
            POS: Position embeddings from the sub-branch; added to layer4's
                output, so it must broadcast against it.

        Returns:
            ``(logits, attn1_flat)``: classifier output and layer1's attention
            map flattened per sample.
        """
        x = self.relu(self.bn1(self.conv1(x)))
        # Main branch: each stage receives the previous stage's pooled attention.
        x, attn1, _ = self.layer1((x, None, True))
        first_attn = attn1
        x, attn2, _ = self.layer2((x, self.maxpool(attn1), True))
        x, attn3, _ = self.layer3((x, self.maxpool(attn2), True))
        x, _attn4, _ = self.layer4((x, self.maxpool(attn3), True))

        x = x + POS  # fusion of motion-pattern features and position embeddings
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x, first_attn.reshape(x.size(0), -1)

    def forward(self, x, POS):
        return self._forward_impl(x, POS)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ``ResNet`` and optionally load downloaded weights into it.

    Args:
        arch: key into ``model_urls``; only used when ``pretrained`` is truthy.
        block: block class forwarded to ``ResNet``.
        layers: per-stage block counts forwarded to ``ResNet``.
        pretrained: when truthy, download the state dict for ``arch`` and load it.
        progress: passed to ``load_state_dict_from_url`` to show a download
            progress bar.
        **kwargs: extra keyword arguments forwarded to ``ResNet``.

    Returns:
        The constructed model.
    """
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model
    # NOTE(review): load_state_dict is called with its default (strict) mode, so
    # the downloaded weights only load if ``block``/``layers`` reproduce the
    # parameter names and shapes of the architecture behind ``arch``; confirm
    # before enabling ``pretrained``.
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(weights)
    return model

##main branch consisting of CA blocks
def resnet18_pos_attention(pretrained=False, progress=True, **kwargs):
    r"""Build the main branch: a ResNet-style network made of CA blocks.

    The network uses ``CABlock`` with a single block per stage
    (``[1, 1, 1, 1]``). It reuses the ``'resnet18'`` entry for pretrained
    weights, but it is shallower than a standard ResNet-18. Architecture
    family: `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, load the weights registered under ``'resnet18'``
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [1, 1, 1, 1]
    return _resnet('resnet18', CABlock, stage_depths, pretrained, progress, **kwargs)


PC_module.py

# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# -*- coding: utf-8 -*-
#import torch
#import torch.nn as nn
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
import math
import logging
from functools import partial
from collections import OrderedDict

#import torch
#import torch.nn as nn
#import torch.nn.functional as F
import torch.nn.functional as F
from itertools import repeat
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import collections.abc
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along dim 0 is handled independently:
    - with probability ``drop_prob`` it is zeroed entirely;
    - otherwise it is kept and scaled by ``1 / (1 - drop_prob)``, so the
      expected value is unchanged.

    This is the same idea as the DropConnect implementation used for
    EfficientNet-style networks. It is named "drop path" because "Drop Connect"
    refers to a different form of dropout. See
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x: input tensor; dim 0 is treated as the batch dimension.
        drop_prob: probability of dropping a sample. ``None`` or ``0`` disables
            dropping.
        training: dropping only happens when True.

    Returns:
        ``x`` itself when dropping is disabled; otherwise a new tensor with the
        same shape as ``x``.
    """
    # ``not drop_prob`` covers both 0 and None (DropPath's historical default),
    # which previously crashed on ``1 - None`` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.:
        # Every sample is dropped. Dividing by keep_prob would yield inf * 0 = NaN.
        return x * 0.
    # One draw per sample, broadcast over the remaining dims (works for any rank).
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around ``drop_path``. Dropping is active only while the
    module is in training mode (``self.training``).

    Args:
        drop_prob: probability of dropping each sample. ``None`` is accepted
            for backward compatibility and treated as 0.0.
    """
    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        # Normalise None to 0.0 so drop_path always receives a number; a
        # default-constructed DropPath is then an identity in every mode.
        self.drop_prob = 0. if drop_prob is None else drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return f'drop_prob={self.drop_prob}'
def _ntuple(n):
? ?def parse(x):
? ? ? ?if isinstance(x, collections.abc.Iterable):
? ? ? ? ? ?return x
? ? ? ?return tuple(repeat(x, n))
? ?return parse


# Shorthand converters: scalar -> fixed-length tuple (iterables pass through).
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple

# Public API. Only names actually defined in this module may appear here:
# ``from <module> import *`` raises AttributeError for any listed name that
# does not exist.
__all__ = [
    'drop_path', 'DropPath', 'Mlp', 'Attention', 'Block',
    'PatchEmbed', 'HybridEmbed', 'VisionTransformer_POS',
]

class Mlp(nn.Module):
    """Two-layer feed-forward network.

    The layer order is Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: size of the last input dimension.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability. One Dropout module is applied after the
            activation and again after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        # Attribute names (fc1/act/fc2/drop) form the state_dict keys; keep them.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: token embedding size ``C``; each head gets ``dim // num_heads`` channels.
        num_heads: number of attention heads.
        qkv_bias: add a bias to the fused q/k/v projection.
        qk_scale: override for the attention scale; falls back to
            ``head_dim ** -0.5`` when falsy.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE: the scale factor was wrong in the original upstream version; it can
        # be set manually via qk_scale to stay compatible with previous weights.
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C); returns (B, N, C)."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, C // heads)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
        # Scaled dot-product attention weights: (B, heads, N, N).
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, heads, N, C // heads) -> (B, N, C).
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer encoder block.

    forward computes ``x + DP(Attn(norm1(x)))`` followed by
    ``x + DP(MLP(norm2(x)))``. ``DP`` is a single stochastic-depth module shared
    by both residual branches: ``DropPath(drop_path)`` when ``drop_path > 0``,
    otherwise ``nn.Identity``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Registration order (norm1, attn, drop_path, norm2, mlp) is kept as-is:
        # it determines state_dict ordering and parameter-init RNG consumption.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x):
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Cuts the image into non-overlapping ``patch_size`` patches using a
    convolution whose kernel equals its stride, then projects each patch to
    ``embed_dim``. The input (B, in_chans, H, W) must exactly match
    ``img_size``. The output is (B, num_patches, embed_dim).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        grid_rows = self.img_size[0] // self.patch_size[0]
        grid_cols = self.img_size[1] // self.patch_size[1]
        self.num_patches = grid_cols * grid_rows
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        _, _, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, rows, cols) -> (B, embed_dim, rows*cols) -> (B, rows*cols, embed_dim)
        patches = self.proj(x).flatten(2)
        return patches.transpose(1, 2)


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.

    Args:
        backbone: CNN module; if it returns a list/tuple, the last item is used.
        img_size: input image size.
        feature_size: spatial size of the backbone output. When None it is
            found by running a zero image through the backbone.
        in_chans: number of input channels (used only for that probe).
        embed_dim: output embedding dimension (1x1 conv projection).
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is not None:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        else:
            feature_size, feature_dim = self._probe_backbone(in_chans, img_size)
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, 1)

    def _probe_backbone(self, in_chans, img_size):
        """Return (spatial size, channel count) of the backbone output for a zero image."""
        # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
        # map for all networks, the feature metadata has reliable channel and stride info, but using
        # stride to calc feature dim requires info about padding of each stage that isn't captured.
        with torch.no_grad():
            was_training = self.backbone.training
            if was_training:
                self.backbone.eval()
            probe = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
            if isinstance(probe, (list, tuple)):
                probe = probe[-1]  # last feature if backbone outputs list/tuple of features
            self.backbone.train(was_training)
        return probe.shape[-2:], probe.shape[1]

    def forward(self, x):
        features = self.backbone(x)
        if isinstance(features, (list, tuple)):
            features = features[-1]  # last feature if backbone outputs list/tuple of features
        tokens = self.proj(features).flatten(2)
        return tokens.transpose(1, 2)


###Position Calibration Module
class VisionTransformer_POS(nn.Module):
    """ Vision Transformer used as the Position Calibration module.

    Based on `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929

    ``forward`` returns the normalised features of every patch token, with
    shape (B, num_patches, num_features). No class token is prepended, and the
    classifier ``head`` is not applied inside this module.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.15, hybrid_backbone=None, norm_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): stored; only its sign decides whether ``head`` is a Linear layer
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate of the last block (rises linearly from 0 across blocks)
            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
            norm_layer: (nn.Module): normalization layer factory; defaults to LayerNorm(eps=1e-6) when None
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        # Use the caller's norm_layer when given; only fall back to LayerNorm(eps=1e-6).
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # NOTE(review): cls_token is initialised but not used by forward_features;
        # it is kept so existing checkpoints still load.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One learned embedding per patch token, sized from the patch grid
        # (previously hard-coded to 196, i.e. only a 14x14 grid worked).
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head (not applied by forward).
        # NOTE(review): the output size is hard-coded to 5 instead of num_classes;
        # it is left as-is because changing it alters the state_dict shape of
        # existing checkpoints.
        self.head = nn.Linear(self.num_features, 5) if num_classes > 0 else nn.Identity()
        # NOTE(review): to_Mask is not used by forward either; kept for checkpoint compatibility.
        self.to_Mask = nn.Sequential(nn.Linear(self.num_features, 1),
                                     nn.Sigmoid(),
                                     )
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated-normal weights, zero bias. LayerNorm: weight=1, bias=0.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        # Use num_features (not embed_dim) so the head matches pre_logits when
        # representation_size is set, as in __init__.
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Embed patches, add position embeddings and run the transformer blocks.

        Returns token features of shape (B, num_patches, num_features).
        """
        x = self.patch_embed(x)
        x = self.pos_drop(x + self.pos_embed)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return self.pre_logits(x)

    def forward(self, x):
        return self.forward_features(x)

CASME2數據集中存在的問題請自行修改。



MMNET 微表情識別(CASME2數據集)的評論 (共 條)

分享到微博請遵守國家法律
介休市| 丹凤县| 远安县| 福贡县| 内江市| 永康市| 蒙山县| 醴陵市| 新宁县| 绿春县| 东兰县| 澄迈县| 安达市| 遂昌县| 高淳县| 霍山县| 济阳县| 浠水县| 长丰县| 东山县| 洪江市| 合阳县| 湖南省| 磴口县| 瑞丽市| 华安县| 竹山县| 迁安市| 佛冈县| 兴安盟| 花垣县| 合川市| 凤台县| 垣曲县| 乌拉特后旗| 绥化市| 葫芦岛市| 姚安县| 高平市| 莱芜市| 加查县|