[Home]
G2DLDA
MATLAB code for generalized Lp-norm two-dimensional linear
discriminant analysis with regularization (G2DLDA).
(To download the complete MATLAB code, right-click [Code] and choose Save.)
Reference
Chun-Na Li, Yuan-Hai Shao, Wei-Jie Chen, Nai-Yang Deng, "Generalized two-dimensional linear discriminant analysis
with regularization", submitted, 2018.
[Slides]
Main Function
Required file: SLp2DLAD (the class definition below)
classdef SLp2DLAD < LearningAlgorithm
    % SLp2DLAD  Generalized Lp-norm two-dimensional LDA with regularization.
    %
    % Learns a projection matrix W that acts on the left of each d1-by-d2
    % sample matrix. Directions are extracted one at a time: each direction
    % is found by update_w on the data restricted to the orthogonal
    % complement of the directions already found.
    %
    % Usage:
    %   Model = SLp2DLAD(name, sigma, p);
    %   Model = Model.train(TrData);
    %
    % Input  - TrData.X: d1-by-d2-by-m array of training samples (samples
    %                    stacked along the third dimension).
    %          TrData.Y: class labels, one per sample (m entries).
    %          sigma:    the regularization parameter.
    %          p:        the Lp-norm exponent.
    % Output - Model.W:  d1-by-d1 projection matrix; column k is the k-th
    %                    extracted direction. Each column has unit norm and
    %                    is orthogonal to the earlier columns.
    %
    % Example:
    %   for i = 1:50
    %       TrData.X(:,:,i) = rand(32,32);
    %   end
    %   TrData.Y = [ones(25,1); -ones(25,1)];
    %   Model = SLp2DLAD('SLp2DLAD', 0.01, 1.5);
    %   Model = Model.train(TrData);
    %   Proj = Model.W;
    %
    % Reference:
    %   Chun-Na Li, Yuan-Hai Shao, Wei-Jie Chen, and Nai-Yang Deng,
    %   "Generalized two-dimensional linear discriminant analysis with
    %   regularization", submitted 2018.
    %   Version 1.0 --8.Jan/2018
    %
    % Written by Wei-Jie Chen, wjcper2008@126.com

    properties
        sigma = 1;      % regularization parameter (scales the diagonal term H2 in update_w)
        p = 1;          % Lp-norm exponent
        W;              % learned projection matrix, d1-by-d1 after train
        m_Cls;          % number of training samples in each class (nCls-by-1)
        MaxIter = 100;  % maximum iterations of update_w per direction
        Tol = 1e-6;     % update_w stops when norm(w_old - w) < Tol
        nCls;           % number of distinct labels
        idxCls;         % idxCls{k}: indices of the samples with the k-th label
        d1;             % rows of each sample matrix
        d2;             % columns of each sample matrix
        m;              % number of training samples
    end

    methods
        function obj = SLp2DLAD(name, sigma, p)
            % Construct the model; name is forwarded to LearningAlgorithm.
            obj = obj@LearningAlgorithm(name);
            obj.sigma = sigma;
            obj.p = p;
        end

        function obj = train(obj, Data)
            % Learn obj.W from Data.X (d1-by-d2-by-m) and Data.Y (m labels).
            % Prints one progress line per extracted direction.
            X = Data.X;
            Y = Data.Y;
            [obj.d1, obj.d2, obj.m] = size(X);
            label = unique(Y);
            obj.nCls = length(label);
            obj.m_Cls = zeros(obj.nCls, 1);
            obj.idxCls = cell(obj.nCls, 1);
            for k = 1:obj.nCls
                obj.idxCls{k} = find(Y == label(k));
                obj.m_Cls(k) = length(obj.idxCls{k});
            end
            obj.W = [];
            B = eye(obj.d1);   % basis of the subspace still to be searched
            TrainX = X;
            for d = 1:obj.d1
                w = obj.update_w(TrainX);
                % Map the direction found in the reduced space back to R^d1.
                obj.W = [obj.W B*w];
                if d < obj.d1
                    % Orthonormal basis of the complement of span(W); the
                    % next direction is searched in this (d1-d)-dim subspace.
                    % Skipped after the last direction, where it would be empty.
                    B = null(obj.W');
                    TrainX = zeros(obj.d1 - d, obj.d2, obj.m);
                    for i = 1:obj.m
                        TrainX(:,:,i) = B' * X(:,:,i);
                    end
                end
                fprintf('RD for %d dim\n', d);
            end
        end

        function w = update_w(obj, X)
            % Find one unit-norm projection direction for the samples in X.
            %
            % X is dd1-by-d2-by-m (dd1 <= d1 after deflation in train). Each
            % iteration builds a reweighted matrix H from the within-class
            % deviations Z plus a sigma-scaled diagonal term, and a vector h
            % from the class-mean deviations V. It then sets w proportional to
            % H\h and normalizes. Iteration stops when w changes by less than
            % obj.Tol or after obj.MaxIter iterations.
            dd1 = size(X, 1);
            M = mean(X, 3);                                  % overall mean
            MCls = zeros(dd1, obj.d2, obj.nCls);
            for k = 1:obj.nCls
                MCls(:,:,k) = mean(X(:,:,obj.idxCls{k}), 3); % class means
            end
            V = zeros(dd1, obj.d2, obj.nCls);  % class mean minus overall mean
            Z = zeros(dd1, obj.d2, obj.m);     % sample minus its class mean
            for k = 1:obj.nCls
                V(:,:,k) = MCls(:,:,k) - M;
                Z(:,:,obj.idxCls{k}) = X(:,:,obj.idxCls{k}) - MCls(:,:,k);
            end
            w = ones(dd1, 1);
            w = w / norm(w);
            for iter = 1:obj.MaxIter
                w_old = w;
                % Within-class term: each column z of Z_d is weighted by
                % 1/(|w'*z|^(2-p) + eps); eps avoids division by zero.
                H1 = zeros(dd1, dd1);
                for d = 1:obj.d2
                    Z_d = permute(Z(:,d,:), [1 3 2]);        % dd1-by-m
                    H1 = H1 + (Z_d ./ (abs(w'*Z_d).^(2-obj.p) + eps)) * Z_d';
                end
                H2 = obj.sigma * diag(1 ./ (abs(w).^(2-obj.p) + eps));
                H = H1 + H2;
                % Between-class term, weighted by class sizes m_Cls.
                % NOTE(review): the exponent below is (1-p); the gradient of
                % |w'*v|^p would use (p-1). The two coincide only for p = 1.
                % Confirm against the paper before relying on p ~= 1.
                h = zeros(dd1, 1);
                for d = 1:obj.d2
                    V_d = permute(V(:,d,:), [1 3 2]);        % dd1-by-nCls
                    wV_d = V_d' * w;
                    h = h + V_d * (obj.m_Cls .* (abs(wV_d).^(1-obj.p)) .* sign(wV_d));
                end
                Hivh = H \ h;
                denom = h' * Hivh;
                if denom == 0 || ~isfinite(denom)
                    % h is zero or the solve failed: dividing would make w
                    % NaN (and poison every later column of W), so keep the
                    % current direction instead.
                    w = w_old;
                    break
                end
                w = Hivh / denom;
                w = w / norm(w);
                if norm(w_old - w) < obj.Tol
                    break
                end
            end
        end
    end
end
For questions or suggestions, please email na1013na@163.com or wjcper2008@126.com.
- Last updated: Jan 8, 2018