G2DLDA

MATLAB code for generalized Lp-norm two-dimensional linear discriminant analysis with regularization. (To download the complete MATLAB code, right-click [Code] and choose Save.)


Reference

Chun-Na Li, Yuan-Hai Shao, Wei-Jie Chen, and Nai-Yang Deng, "Generalized two-dimensional linear discriminant analysis with regularization," submitted, 2018.

[Slides]


Main Function

Requires the function SLp2DLAD, a class derived from LearningAlgorithm:

classdef SLp2DLAD < LearningAlgorithm
    %
    % Usage:
    %   Input  - TrData: the training data
    %                    (TrData.X: d1-by-d2-by-m array, TrData.Y: m-by-1 labels)
    %            sigma:  the regularization parameter
    %            p:      the order of the Lp-norm
    %   Output - Model.W: the projection matrix
    %
    % Example:
    %   close all; clear variables;
    %   for i = 1:50
    %       TrData.X(:,:,i) = rand(32,32);
    %   end
    %   TrData.Y = [ones(25,1); -ones(25,1)];
    %   Model = SLp2DLAD('SLp2DLAD', 0.01, 1.5);
    %   Model = Model.train(TrData);
    %   Poj = Model.W;
    %
    % Reference:
    %   Chun-Na Li, Yuan-Hai Shao, Wei-Jie Chen, and Nai-Yang Deng,
    %   "Generalized two-dimensional linear discriminant analysis with
    %   regularization", submitted 2018
    %
    % Version 1.0, 8 Jan 2018
    %
    % Written by Wei-Jie Chen, wjcper2008@126.com

    properties
        sigma = 1;       % regularization parameter
        p = 1;           % order of the Lp-norm
        W;               % projection matrix, one column per extracted direction
        m_Cls;           % number of samples in each class
        MaxIter = 100;   % maximum number of iterations per direction
        nCls;            % number of classes
        idxCls;          % sample indices of each class
        d1;              % number of rows of each sample
        d2;              % number of columns of each sample
        m;               % number of samples
    end

    methods
        function obj = SLp2DLAD(name, sigma, p)
            obj = obj@LearningAlgorithm(name);
            obj.sigma = sigma;
            obj.p = p;
        end

        function obj = train(obj,Data)
            X = Data.X;
            Y = Data.Y;
            [obj.d1,obj.d2,obj.m] = size(X);

            % Collect the sample indices and the size of each class.
            label = unique(Y);
            obj.nCls = length(label);
            obj.m_Cls = zeros(obj.nCls,1);
            obj.idxCls = cell(obj.nCls,1);
            for k = 1:obj.nCls
                obj.idxCls{k} = find(Y==label(k));
                obj.m_Cls(k) = length(obj.idxCls{k});
            end

            % Extract the directions one at a time; after each one, project
            % the data onto the null space of the directions found so far.
            obj.W = [];
            I = eye(obj.d1);
            TrainX = X;
            B = I;
            for d = 1:obj.d1
                w = obj.update_w(TrainX);
                obj.W = [obj.W B*w];
                B = null(obj.W');
                TrainX = zeros(obj.d1-d,obj.d2,obj.m);
                for i = 1:obj.m
                    TrainX(:,:,i) = B'*X(:,:,i);
                end
                fprintf('RD for %d dim\n', d);
            end
        end

        function w = update_w(obj,X)
            dd1 = size(X,1);

            % Overall mean and class means.
            M = mean(X,3);
            MCls = zeros(dd1,obj.d2,obj.nCls);
            for k = 1:obj.nCls
                MCls(:,:,k) = mean(X(:,:,obj.idxCls{k}),3);
            end

            % V: class means minus the overall mean.
            % Z: samples minus their own class mean.
            V = zeros(dd1,obj.d2,obj.nCls);
            Z = zeros(dd1,obj.d2,obj.m);
            for k = 1:obj.nCls
                V(:,:,k) = MCls(:,:,k) - M;
                Z(:,:,obj.idxCls{k}) = X(:,:,obj.idxCls{k}) - MCls(:,:,k);
            end

            % Iterative update of w, starting from a normalized all-ones vector.
            w = ones(dd1,1);
            w = w/norm(w);
            for iter=1:obj.MaxIter
                w_old = w;

                % Reweighted within-class term.
                H1 = zeros(dd1,dd1);
                for d = 1:obj.d2
                    Z_d = permute(Z(:,d,:),[1 3 2]);
                    H1 = H1 + (Z_d./(abs(w'*Z_d).^(2-obj.p) + eps))*Z_d';
                end

                % Reweighted regularization term.
                H2 = obj.sigma * diag(1./(abs(w).^(2-obj.p) + eps));
                H = H1 + H2;

                % Between-class term.
                h = zeros(dd1,1);
                for d=1:obj.d2
                    V_d = permute(V(:,d,:),[1 3 2]);
                    wV_d = V_d'*w;
                    h = h + V_d * (obj.m_Cls.*(abs(wV_d).^(1-obj.p)).*sign(wV_d));
                end

                Hivh = H\h;
                w = Hivh/(h'*Hivh);
                w = w/norm(w);

                % Stop once w has converged.
                if norm(w_old - w) < 10^(-6)
                    break
                end
            end
        end
    end
end

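
The class above inherits from LearningAlgorithm, which is not listed on this page. As a rough stand-in, assuming the base class only needs to store the name passed to its constructor, something like the following sketch may be enough to run the listing. The property name "name" and the whole class body are assumptions for illustration, not the authors' definition.

```matlab
% LearningAlgorithm.m
% Hypothetical minimal base class (an assumption, not the authors' code).
% It only stores the name that SLp2DLAD passes via obj@LearningAlgorithm(name).
classdef LearningAlgorithm
    properties
        name   % label of the algorithm, e.g. 'SLp2DLAD' (assumed property)
    end
    methods
        function obj = LearningAlgorithm(name)
            obj.name = name;
        end
    end
end
```

Saved as LearningAlgorithm.m on the MATLAB path next to SLp2DLAD.m, this should let the usage example in the header comment construct and train a model, provided the training data struct carries both X and Y. Since train applies its projections on the left (B'*X), the first r columns of the result would presumably be applied to the k-th sample as Model.W(:,1:r)' * TrData.X(:,:,k), where r and k are placeholders chosen by the user.
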
Contacts


For questions or suggestions, please email na1013na@163.com or wjcper2008@126.com.