% function polybayes(n,deg,plotlog)
%
% n = number of points (default 8)
% deg = max degree of polynomial (default 5)
% plotlog = 1 plots log evidence, = 0 doesn't plot it (default 0)
%
% This function reproduces the plots for Bayesian polynomial
% regression in my talk.
%
% (c) Zoubin Ghahramani 2000

function genreg(n,deg,plotlog)

% Bayesian polynomial regression demo.
%
% Draws n noisy samples of y = (x-4)^2 + 3*noise with x uniform on [0,10],
% then, for every polynomial order M = 0..deg:
%   * computes the log marginal likelihood (evidence) of the order-M model,
%   * prints the number of parameters and the regularised residual,
%   * plots the least-squares fit and the data in a subplot of figure 1.
% Figure 2 shows the evidences normalised to sum to one; figure 3
% (only if plotlog is true) shows the log evidences.
%
% Inputs (all optional; missing or empty values fall back to the default):
%   n       - number of data points               (default 8)
%   deg     - maximum polynomial degree, >= 0     (default 5)
%   plotlog - nonzero to also plot log evidence   (default 0)
%
% Outputs: none. Side effects: writes to figures 1, 2 (and 3), prints one
% line per model to the command window, and consumes random numbers.

if nargin<1 || isempty(n),       n=8;       end
if nargin<2 || isempty(deg),     deg=5;     end
if nargin<3 || isempty(plotlog), plotlog=0; end

% A negative degree leaves no model to fit; fail here with a clear message
% instead of later inside the plotting calls.
if deg<0
    error('genreg:negativeDegree','deg must be >= 0 (got %g).',deg);
end

% ---- Synthetic data (draw order: inputs first, then noise) ----
x=rand(n,1)*10;        % input points
e=randn(n,1);          % noise
y=(x-4).^2 + 3*e;      % noisy quadratic target

xx=(-2:0.1:12)';       % uniform grid for plotting
m=length(xx);

orders=0:deg;              % polynomial orders M to compare
nModels=numel(orders);
nCols=ceil((deg+1)/2);     % subplot grid is 2 x nCols

% ---- Prior hyperparameters ----
% gamm  : scales the prior precision of the Gaussian on the coefficients
% alpha : SHAPE of the Gamma prior on the noise precision
% beta  : RATE (inverse scale) of the Gamma prior on the noise precision
% (alpha/beta roles follow from the evidence term
%  alpha*log(beta) - gammaln(alpha) and from gamrnd(alp,1/bet) below.)
sensible=1;
if sensible
    gamm=0.001;
    beta=0.00001;
    alpha=0.1;
else
    gamm=0.1;
    beta=0.1;
    alpha=0.1;
end

% Posterior sampling is disabled, as in the original code. Set to 1 to
% overlay nSamples posterior draws (requires gamrnd, Statistics Toolbox).
drawSamples=0;
nSamples=20;

figure(1);
set(gcf,'Position',[8   220   560   420]);
clf;

lne=zeros(1,nModels);  % log evidence for each order
X=zeros(n,0);          % design matrix on the data, grown one column per order
XX=zeros(m,0);         % design matrix on the plotting grid
EYY=y'*y/n;            % does not depend on the model order

for k=1:nModels
    M=orders(k);
    d=k;                   % number of coefficients of the order-M model

    % Append the next monomial; t.^0 yields the constant column.
    X=[X x.^M];            %#ok<AGROW>
    XX=[XX xx.^M];         %#ok<AGROW>

    EXX=X'*X/n;
    EYX=y'*X/n;
    S=EXX+(gamm/n)*eye(d); % regularised second-moment matrix

    % Solve the linear system directly instead of forming inv(S).
    wMean=S\EYX';          % posterior mean of the coefficients
    err=EYY-EYX*wMean;
    fprintf('param=%g err=%g \n',d,err);

    % log|S| from the LU factors: avoids overflow/underflow of det(S).
    % S is positive definite in exact arithmetic, so |det| = det.
    [Lf,Uf]=lu(S);         %#ok<ASGLU>
    logDetS=sum(log(abs(diag(Uf))));

    % COMPUTE THE EVIDENCE (i.e. MARGINAL LIKELIHOOD)
    lne(k)= -(n/2)*log(2*pi) + alpha*log(beta) - gammaln(alpha) ...
            + (d/2)*log(gamm/n) - 0.5*logDetS ...
            + gammaln(n/2+alpha) ...
            - (n/2+alpha)*log(beta+(n/2)*err);

    ax=subplot(2,nCols,k);

    % Optional: samples from the posterior distribution of polynomials.
    if drawSamples
        alp=alpha+n/2;         % posterior shape of the noise precision
        bet=beta+0.5*n*err;    % posterior rate of the noise precision
        SInv=S\eye(d);
        for s=1:nSamples
            rho=gamrnd(alp,1/bet);            % gamrnd takes a scale, hence 1/bet
            w=wMean+sqrtm(SInv/(rho*n))*randn(d,1);
            hold on;
            plot(xx,XX*w,'g-','LineWidth',1);
        end
    end

    % Least-squares fit of the same order, plus the data.
    wLS=X\y;
    plot(xx,XX*wLS,'b-','LineWidth',2);
    ttl=title(['M = ' num2str(M)]);
    set(ttl,'FontSize',18);
    hold on;
    set(ax,'FontSize',15);
    plot(x,y,'ro','MarkerFaceColor','r');
    axis([-2 12 -20 50]);
end

% PLOT THE EVIDENCE

figure(2);
set(gcf,'Position',[581   220   417   399]);

% Subtract the maximum before exponentiating so the normalisation cannot
% underflow to 0/0 when the log evidences are very negative (large n).
PP=exp(lne-max(lne));
PP=PP/sum(PP);

bar(orders,PP)
axis([-0.5 deg+0.5 0 1]);

set(gca,'FontSize',16);
aa=xlabel('M'); set(aa,'FontSize',20);
aa=ylabel('P(D|M)'); set(aa,'FontSize',20);

if plotlog
    figure(3);
    % Plot against the actual orders 0..deg so the x-axis matches its 'M' label.
    plot(orders,lne,'.-');
    aa=xlabel('M'); set(aa,'FontSize',20);
    aa=ylabel('log P(D|M)'); set(aa,'FontSize',20);
end
