MATLAB CoderでMEX化

MATLABは行列演算処理は高速なのですが、forループで書くと演算処理が極端に低下します。 for ループを多用するコードではC言語やFortranでコードを書き、MEX化することで演算処理を速くすることになります。 別売りのMATLAB Coderを使うと、関数MファイルをMEXファイル化してくれます。

Contents

三次元グリッドデータの線形補間

MATLABには三次元グリッドデータの線形補間の関数としてinterp3.mがあります。 生憎この関数はMATLAB CoderでMEX化することはできませんでした。 そこで別途hns_interp3.mという関数Mファイルを作成しました。

functionからreturnまでの%を削除してreturnまでコピーしてhns_interp3.mという名前で保存してください。 引数Dはuint8、x,y,zはdoubleのベクトルです。 C言語では型や配列数が厳密に定義されますが、MATLABは基本doubleで配列数はかなり自由がききます。 MEX化する際には関数Mファイルの変数はassertを使って定義しておく必要があります。

% function D2=hns_interp3(D,x,y,z)
% MAX=1024;% 最大配列数
% assert(isa(D,'uint8'));% 三次元配列のクラス int16,int32など
% assert(all(size(D)<=[MAX,MAX,MAX]));
% assert(isa(x,'double'));
% assert(all(size(x)<=[1,MAX]));
% assert(isa(y,'double'));
% assert(all(size(y)<=[1,MAX]));
% assert(isa(z,'double'));
% assert(all(size(z)<=[1,MAX]));
% nz=length(z);
% ny=length(y);
% nx=length(x);
% [mx,my,mz]=size(D);
% D2=repmat(uint8(0),[nx,ny,nz]);% 三次元配列のクラス int16,int32など
%
% for zz=1:nz
%     z0=floor(z(zz));
%     if z0<1;z0=1;end;
%     z1=z0+1;
%     if z1>mz;z1=mz;end;
%     dz=z(zz)-z0;
%     zd=1-dz;
%     for yy=1:ny
%         y0=floor(y(yy));
%         if y0<1;y0=1;end;
%         y1=y0+1;
%         if y1>my;y1=my;end;
%         dy=y(yy)-y0;
%         yd=1-dy;
%         for xx=1:nx
%             x0=floor(x(xx));
%             if x0<1;x0=1;end;
%             x1=x0+1;
%             if x1>mx;x1=mx;end;
%             dx=x(xx)-x0;
%             xd=1-dx;
%             k1=zd*yd*xd*double(D(x0,y0,z0));
%             k2=zd*yd*dx*double(D(x1,y0,z0));
%             k3=zd*dy*xd*double(D(x0,y1,z0));
%             k4=zd*dy*dx*double(D(x1,y1,z0));
%             k5=dz*yd*xd*double(D(x0,y0,z1));
%             k6=dz*yd*dx*double(D(x1,y0,z1));
%             k7=dz*dy*xd*double(D(x0,y1,z1));
%             k8=dz*dy*dx*double(D(x1,y1,z1));
%             D2(xx,yy,zz)=uint8(k1+k2+k3+k4+k5+k6+k7+k8);
%         end;
%     end;
% end;
% return;

MATLAB Coderのcodegen関数を用いてMEXファイルを作成

hns_interp3_mex.拡張子という名前で保存されます。拡張子は64bit Windowsはmexw64、 32bit Windowsはmexw32、Macはmaci64、Linuxはglnxa64です。

codegen('hns_interp3.m');

C++言語で書き直し

hns_interp3.mと同じものをC++言語で書き直しました。 以下のファイルを% %を削除してhns_interp3_uint8.cppという名前で保存してください。

% % #include "mex.h"
% % #include "matrix.h"
% % #include<math.h>
% %
% % void mexFunction(int nlhs,mxArray *plhs[],int nrhs, const mxArray *prhs[]){
% %     unsigned char *D;//引数   int16の三次元配列
% %     unsigned char *R;//返り値  uint8の三次元配列
% %     const int *Dsize;//Dの大きさ
% %     double *xx,*yy,*zz;
% %     D=(unsigned char*)mxGetData(prhs[0]);
% %     Dsize=mxGetDimensions(prhs[0]);
% %     int mx=*Dsize,my=*(Dsize+1),mz=*(Dsize+2);
% %     xx=mxGetPr(prhs[1]);
% %     int nx=(int)mxGetNumberOfElements(prhs[1]);
% %     yy=mxGetPr(prhs[2]);
% %     int ny=(int)mxGetNumberOfElements(prhs[2]);
% %     zz=mxGetPr(prhs[3]);
% %     int nz=(int)mxGetNumberOfElements(prhs[3]);
% %     int nn[3]={nx,ny,nz};
% %     plhs[0]=mxCreateNumericArray(3,nn,mxUINT8_CLASS,mxREAL);
% %     R=(unsigned char*)mxGetData(plhs[0]);
% %     int x,y,z,mxy=mx*my;
% %     int z0,z1,y0,y1,x0,x1;
% %     int y0mx,y1mx,z0mxy,z1mxy;
% %     double zd,dz,yd,dy,xd,dx;
% %     double k[8];
% %     //mexPrintf("%d %d %d\n",mx,my,mz);
% %     for(z=0;z<nz;++z){
% %         z0=floor(*(zz+z))-1;
% %         if(z0<0){z0=0;}
% %         z1=z0+1;
% %         if(z1>mz-1){z1=mz-1;}
% %         dz=*(zz+z)-(double)z0-1.0;
% %         zd=1.0-dz;
% %         z0mxy=z0*mxy;
% %         z1mxy=z1*mxy;
% %         for(y=0;y<ny;++y){
% %             y0=floor(*(yy+y))-1;
% %             if(y0<0){y0=0;}
% %             y1=y0+1;
% %             if(y1>my-1){y1=my-1;}
% %             dy=*(yy+y)-(double)y0-1.0;
% %             yd=1.0-dy;
% %             y0mx=y0*mx;
% %             y1mx=y1*mx;
% %             for(x=0;x<nx;++x){
% %                 x0=floor(*(xx+x))-1;
% %                 if(x0<0){x0=0;}
% %                 x1=x0+1;
% %                 if(x1>mx-x){x1=mx-1;}
% %                 dx=*(xx+x)-(double)x0-1.0;
% %                 xd=1.0-dx;
% %                 k[0]=zd*yd*xd*(double)*(D+x0+y0mx+z0mxy);
% %                 k[1]=zd*yd*dx*(double)*(D+x1+y0mx+z0mxy);
% %                 k[2]=zd*dy*xd*(double)*(D+x0+y1mx+z0mxy);
% %                 k[3]=zd*dy*dx*(double)*(D+x1+y1mx+z0mxy);
% %                 k[4]=dz*yd*xd*(double)*(D+x0+y0mx+z1mxy);
% %                 k[5]=dz*yd*dx*(double)*(D+x1+y0mx+z1mxy);
% %                 k[6]=dz*dy*xd*(double)*(D+x0+y1mx+z1mxy);
% %                 k[7]=dz*dy*dx*(double)*(D+x1+y1mx+z1mxy);
% %                 *R=(unsigned char)(k[0]+k[1]+k[2]+k[3]+k[4]+k[5]+k[6]+k[7]);
% %                 R++;
% %             }
% %         }
% %     }
% %
% % }

MATLAB本体のmex関数を使ってMEX化

hns_interp3_uint8.拡張子という名前で保存されます。拡張子は64bit Windowsはmexw64、 32bit Windowsはmexw32、Macはmaci64、Linuxはglnxa64です。

mex('hns_interp3_uint8.cpp');

mriファイルの読み込み

MATLABに頭部MRIのデータがあります。これを読み込みます。

load mri;
D=squeeze(D);

三次元線形補間とMRIの表示

読み込んだmriのファイルのvoxelの大きさは[2,2,5]mmです。線形補間して0.5mmの立方体にします。 MATLABの関数interp3.mと関数Mファイルhns_interp3.mとMEX化したhns_interp3_mexを比較します。 MEX化することで処理速度が高速化されています。 Windows版だと残念ながらMEX化してもMATLABの関数interp3.mの処理速度には適いませんでしたが、 メモリ使用量はinterp3.mがdoubleで計算するため512x512x270x8(double)x4(変数)=2160MBですが、 hns_interp3_mexだと512x512x270x1(uint8)x1(変数)=67.5MBと大幅に改善されています。

尚、文字列の半角¥マークはアンダーバーを表示するための記号です。

figure('color',[1,1,1]);
for n=1:3;
    subplot(1,3,n);
    switch n
        case 1;imagesc(squeeze(D(round(siz(1)/2),:,:)));
        case 2;imagesc(squeeze(D(:,round(siz(2)/2),:)));
              title('線形補間前');
        case 3;imagesc(squeeze(D(:,:,round(siz(3)/2))));
    end;
    daspect([1,1,1]);
end;
colormap(gray);

siz2=siz.*[2,2,5];% voxel sizeを[2,2,5]mmとしました。
siz2=siz2*2;% 0.5mm voxel
x=linspace(1,siz(1),siz2(1));
y=linspace(1,siz(2),siz2(2));
z=linspace(1,siz(3),siz2(3));

% MATLABの関数interp3.mを使用
tic;
[xx,yy,zz]=meshgrid(x,y,z);
D1=interp3(double(D),xx,yy,zz);
D1=uint8(D1);
st1=toc;
% hns_interp3.mを使用
tic;
D2=hns_interp3(D,x,y,z);
st2=toc;

% MEX化したhns_interp3_mexを使用
tic;
D3=hns_interp3_mex(D,x,y,z);
st3=toc;

% MEX化したhns_interp3_uint8を使用
tic;
D4=hns_interp3_uint8(D,x,y,z);
st4=toc;

for n=1:4;
    switch n;
        case 1;DD=D1;st=['interp3.mの処理時間は',sprintf('%0.2f秒です',st1)];
        case 2;DD=D2;st=['hns\_interp3.mの処理時間は',sprintf('%0.2f秒です',st2)];
        case 3;DD=D3;st=['hns\_interp3\_mexの処理時間は',sprintf('%0.2f秒です',st3)];
        case 4;DD=D3;st=['hns\_interp3\_uint8の処理時間は',sprintf('%0.2f秒です',st4)];
    end;
    figure('color',[1,1,1]);
    for m=1:3
        subplot(1,3,m);
        switch m
            case 1;imagesc(squeeze(DD(round(siz2(1)/2),:,:)));
            case 2;imagesc(squeeze(DD(:,round(siz2(2)/2),:)));
                   title(st);
            case 3;imagesc(squeeze(DD(:,:,round(siz2(3)/2))));
        end;
        daspect([1,1,1]);
    end;
    colormap(gray);
end;