[Home]
2DLDAL1S
MATLAB code for sparse L1-norm two-dimensional linear discriminant
analysis via generalized elastic net regularization.
(To download the complete MATLAB code, right-click [Code] and choose Save.)
Reference
Chun-Na Li, Meng-Qi Shang, Yuan-Hai Shao*, Zhen Wang, Nai-Yang Deng, "Sparse L1-norm two dimensional linear discriminant
analysis via the generalized elastic net regularization," submitted, 2018.
Main Function
function [V,X] = S2DLDAL1(X,Y,v0,sigma,delta,p,dim)
% S2DLDAL1  Extract projection vectors one at a time, deflating the data
%           after each one.
%
% Usage:
%   [V,X] = S2DLDAL1(X,Y,v0,sigma,delta,p,dim)
%
% Input  - X:     the training data, a 3-dimensional array of size d1*d2*N
%          Y:     the label vector corresponding to X
%          v0:    the initial projection vector (d2*1); the same v0 is used
%                 to start the search for every dimension
%          sigma: the L2-norm regularization term parameter
%          delta: the Lp-norm regularization term parameter
%          p:     the exponent p in the Lp-norm
%          dim:   the reduced dimension (number of projection vectors)
% Output - V:     the d2*dim projection matrix, one vector per column
%          X:     the training data after the last deflation step, i.e.
%                 the result of updata(original X, V)
%
% Side effects: clears the command window (clc) and prints a progress line
% after each extracted dimension.
%
% Example:
%   for i = 1:10
%       X(:,:,i) = rand(32,32);
%   end
%   Y = [ones(5,1);-ones(5,1)];
%   v0 = ones(32,1);
%   sigma = 10^4;
%   delta = 10^-3;
%   p = 1;
%   dim = 10;
%   [V,X] = S2DLDAL1(X,Y,v0,sigma,delta,p,dim);
%
% Reference:
%   Chun-Na Li, Meng-Qi Shang, Yuan-Hai Shao, and Nai-Yang Deng,
%   "Sparse L1-norm two dimensional linear discriminant analysis via the
%   generalized elastic net regularization", submitted 2018
%   Version 1.0 -- 15.April/2018
%
% Written by Meng-Qi Shang,17858527466@163.com

% Keep the untouched input: every deflation restarts from the original
% data and applies all columns of V found so far.
origData = X;
numCols = size(origData,2);
V = zeros(numCols,dim);

for d = 1:dim
    % Solve for the next direction on the currently deflated data.
    V(:,d) = main2(X,Y,v0,sigma,delta,p);
    % Columns of V not yet computed are zero and do not affect the result.
    X = updata(origData,V);
    clc
    fprintf('The %d-the dimension is done\n',d)
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [v] = main2(X,Y,v,sigma,delta,p)
% MAIN2  Iterate main1 until the projection vector converges, then
%        normalize it.
%
% Input  - X:     the training data (d1*d2*N)
%          Y:     the label vector corresponding to X
%          v:     the initial projection vector (d2*1)
%          sigma, delta, p: regularization parameters, passed to main1
% Output - v:     the final iterate scaled to unit L2 norm
%
% Iteration stops when the largest absolute change of any entry between
% two successive iterates is at most 1e-4, or after 100 updates.

tol = 0.0001;
maxUpdates = 100;

% Start with an infinite "change" so that at least one update is always
% performed. (Comparing the initial v against 0 would skip the whole loop
% whenever every entry of the initial vector is <= tol in magnitude.)
change = Inf;
numUpdates = 0;
while change > tol && numUpdates < maxUpdates
    v_prev = v;
    v = main1(X,Y,v_prev,sigma,delta,p);
    change = max(abs(v - v_prev));
    numUpdates = numUpdates + 1;
end

% Scale to unit Euclidean length.
v = v./sqrt(v'*v);
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [v]=main1(X,lable,v0,sigma,delta,p)
% MAIN1  One update step of the projection vector.
%
% Input  - X:     the training data (d1*d2*N); X(:,:,i) is one sample
%          lable: the label vector; entry i is the class of X(:,:,i)
%          v0:    the current projection vector (d2*1)
%          sigma: weight of the identity matrix added to H
%          delta: weight of the Lp term G
%          p:     exponent of the Lp term
% Output - v:     the updated (not normalized) projection vector (d2*1)
%
% The step builds
%   h = sum over classes of  num_c * (M_c - M)' * sign((M_c - M)*v0)
%   H = sum over all sample rows z (row minus class-mean row) of
%       z'*z/abs(z*v0),  plus sigma*I
%   G = delta * abs(v0).^(p-1) .* sign(v0)
% and returns v = lambda*(H\h) - H\G with
%   lambda = (h'*(H\G) + 1) / (h'*(H\h)),
% which satisfies h'*v = 1.

classes = unique(lable);            % avoid shadowing the builtin "class"
numClasses = length(classes);
totalMean = mean(X,3);              % mean matrix over all samples
[m,n,~] = size(X);
h = zeros(n,1);
H = zeros(n);

for i = 1:numClasses
    [classData,classMean,num] = find_matrix(X,lable,classes(i));
    Y = classMean - totalMean;
    % Equivalent to sum(num*(ones(n,1)*sign(Y*v0)').*Y',2).
    h = h + num*(Y'*sign(Y*v0));
    for j = 1:num
        sample = classData(:,:,j);
        for k = 1:m
            Z = sample(k,:) - classMean(k,:);
            if ~any(Z ~= 0)
                % The row equals the class-mean row, so Z'*Z is all zeros
                % and Z*v0 is 0 for every v0: the term would be 0/0 = NaN
                % and would turn all of H into NaN. It contributes nothing.
                continue
            end
            if abs(Z*v0) <= 10^-4
                % Nudge v0 away from a (near-)zero denominator. The shifted
                % v0 is kept for all later terms, including G below.
                v0 = v0 + 0.0002*ones(n,1);
            end
            H = H + (Z'*Z/abs(Z*v0));
        end
    end
end

H = H + sigma*eye(n);
G = delta*((abs(v0)).^(p-1).*sign(v0));

% NOTE: "*" and "\" have equal precedence and associate left to right, so
% an expression of the form  s*H\h  is evaluated as (s*H)\h = (H\h)/s.
% The multiplier is therefore applied explicitly to H\h here.
Hinv_h = H\h;
Hinv_G = H\G;
lambda = (h'*Hinv_G + 1)/(h'*Hinv_h);
v = lambda*Hinv_h - Hinv_G;
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [sample_data,mean_data,num]=find_matrix(X,Y,C)
% FIND_MATRIX  Select the samples of one class.
%
% Input  - X: sample data (d1*d2*N); X(:,:,i) is one sample
%          Y: label vector; entry i is the label of X(:,:,i)
%          C: the class label to select
% Output - sample_data: the samples whose label equals C (d1*d2*num)
%          mean_data:   the mean of those samples along the 3rd dimension
%          num:         the number of selected samples

sample_data = X(:,:,Y == C);
mean_data = mean(sample_data,3);
% Count along the sample dimension directly; this also works when d1 or d2
% is zero, where indexing element (1,1,:) would fail.
num = size(sample_data,3);
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [newX] = updata(X,V)
% UPDATA  Deflate the data: right-multiply every sample by (I - V*V').
%
% Input  - X: sample data (d1*d2*N); X(:,:,i) is one sample
%          V: matrix with d2 rows; each column is a projection vector
% Output - newX: array of the same size as X with
%                newX(:,:,i) = X(:,:,i)*(eye(d2) - V*V')

[~,m,n] = size(X);
% The deflation matrix does not depend on the sample, so build it once.
P = eye(m) - V*V';
newX = zeros(size(X));
for i = 1:n
    newX(:,:,i) = X(:,:,i)*P;
end
end
For questions or suggestions, please email na1013na@163.com
- Last updated: Apr 15, 2018