module marginallikelihood
	use compute
	use dgp
	use, intrinsic :: ieee_arithmetic, only: ieee_value, ieee_quiet_nan
	implicit none
	
	! Copies of the prior-prior hyperparameter arrays taken at the start of setml (right after
	! prep_real). setmodel(1,...) rescales their variance slices by hpscale relative to these copies.
	! s2loads_prior0 is filled in setml but not read anywhere in this module.
	real, allocatable	:: s2coefs_prior_prior0(:,:,:),s2loads_prior_prior0(:,:,:,:), &
							   s2loads_prior0(:,:,:,:),s2phi_prior_prior0(:,:,:,:), &
							   tdf_prior_prior0(:,:),tdfe_prior_prior0(:,:)
	! Multiplier applied by setmodel(1,...) to all prior-prior variances; 1 = unscaled.
	real				:: hpscale=1

	! Gaussian log-density kernel -.5*m**2/v-.5*log(v) (the -.5*log(2*pi) constant is omitted).
	! m is the deviation from the mean, v the variance; array versions sum over all elements.
	interface ll_normal
		module procedure ll_normal_s, ll_normal_v, ll_normal_m
	end interface
	
	contains
	
	!> Scalar Gaussian log-kernel.
	pure function ll_normal_s(m,v) result(val)
		real, intent(in)	:: m,v
		real	:: val
		val=-.5*m**2/v-.5*log(v)
	end function

	!> Rank-1 Gaussian log-kernel, summed over elements (m and v must conform).
	pure function ll_normal_v(m,v) result(val)
		real, intent(in)	:: m(:),v(:)
		real	:: val
		val=-.5*sum(m**2/v)-.5*sum(log(v))
	end function
	
	!> Rank-2 Gaussian log-kernel, summed over elements (m and v must conform).
	pure function ll_normal_m(m,v) result(val)
		real, intent(in)	:: m(:,:),v(:,:)
		real	:: val
		val=-.5*sum(m**2/v)-.5*sum(log(v))
	end function
	
	!> Iterative bridge-sampling (Meng-Wong form, equal weights) estimate of log(Z1/Z2).
	!>  LR1: draws from model 1, each holding log(q2/q1) evaluated at that draw.
	!>  LR2: draws from model 2, each holding log(q1/q2) evaluated at that draw.
	!>  Iterates  lnr <- log sum_j e^{LR2_j}/(e^{LR2_j}+e^{lnr}) - log sum_i e^{LR1_i}/(1+e^{lnr+LR1_i}),
	!>  evaluating both sums after shifting by their maxima to avoid overflow.
	!>  Plain sums (not means) are used, so this assumes size(LR1)==size(LR2), as in setml.
	!>  Stops when successive iterates differ by less than tol. After max_iter iterations
	!>  it prints a warning and returns the last iterate.
	function getsingleBF(LR1,LR2) result(val)
		real, intent(in)	:: LR1(:),LR2(:)
		real	:: val
		integer, parameter	:: max_iter=500
		real, parameter	:: tol=1E-3
		real	:: lnr,last_lnr,lnum,cn,cd,tmp
		integer	:: i

		lnr=.5*maxval(LR2)-.5*maxval(LR1)	! start with solution that would obtain if LR1 and LR2 were both dominated by largest value
		last_lnr=lnr
		do i=1,max_iter
			! log numerator: log sum exp(LR2)/(exp(LR2)+exp(lnr)), with a stable shift cd>=all exponents
			cn=maxval(LR2)
			cd=max(cn,lnr)
			lnum=cn-cd+log(sum(exp(LR2-cn)/(exp(LR2-cd)+exp(lnr-cd))))

			! log denominator: log sum exp(LR1)/(1+exp(lnr+LR1)), uses the previous lnr on the right
			cn=maxval(LR1)
			cd=max(0.0,lnr+cn)
			lnr=lnum-(cn-cd+log(sum(exp(LR1-cn)/(exp(-cd)+exp(lnr+LR1-cd)))))
			if(abs(last_lnr-lnr)<tol) exit
			! damping: move 90% of the way to the new iterate (keep 10% of the previous one) to avoid oscillations
			tmp=lnr
			lnr=.9*lnr+.1*last_lnr
			last_lnr=tmp
		enddo
		if(i>max_iter) then
			print *,"problem in BF convergence",lnr,last_lnr
		endif
		val=lnr
	end function
	
	
	!> Sum of Gaussian log-kernels (ll_normal) for the quantities governed by component icase.
	!>  Each term is evaluated at the current module/dgp values under the current prior-prior settings,
	!>  so the difference of two getll calls under different setmodel settings gives the
	!>  log density ratio used by getlr. Returns 0 for an unknown icase.
	function getll(icase) result(ll)
		integer, intent(in)	:: icase
		real	:: ll
		ll=0
		select case(icase)
		case(1)		! HP: every term whose variance setmodel(1,...) rescales by hpscale
			ll=ll_normal(s2phi_prior(:,1,:)-s2phi_prior_prior(:,1,:,1),s2phi_prior_prior(:,1,:,2))
			ll=ll+ll_normal(log(s2phi_prior(:,2,:))-s2phi_prior_prior(:,2,:,1),s2phi_prior_prior(:,2,:,2))

			ll=ll+ll_normal(log(tdf_var)-tdf_prior_prior(1,1),tdf_prior_prior(2,1))
			ll=ll+ll_normal(tdfs(0)-tdf_prior_prior(1,0),tdf_prior_prior(2,0))
						
			ll=ll+ll_normal(log(tdfe_var)-tdfe_prior_prior(1,1),tdfe_prior_prior(2,1))
			ll=ll+ll_normal(tdfes(0)-tdfe_prior_prior(1,0),tdfe_prior_prior(2,0))

			ll=ll+ll_normal(log(s2coefs_var)-s2coefs_prior_prior(:,1,1),s2coefs_prior_prior(:,2,1))
			ll=ll+ll_normal(s2coefs(1:ou0,0)-tgs2-s2coefs_prior_prior(1:ou0,1,0),s2coefs_prior_prior(1:ou0,2,0))
			! zero-mean terms sharing the single variance s2coefs_prior_prior(oup,2,0)
			ll=ll+ll_normal(s2coefs(oup:,0),spread(s2coefs_prior_prior(oup,2,0),1,size(s2coefs(oup:,0))))
						
			ll=ll+ll_normal(s2loads_prior(2:,1,1,1)-s2loads_prior_prior(2:,1,1,1),s2loads_prior_prior(2:,1,1,2))
			ll=ll+ll_normal(s2loads_prior(:,1,2,1)-s2loads_prior_prior(:,1,2,1),s2loads_prior_prior(:,1,2,2))
			ll=ll+ll_normal(log(s2loads_prior(:,2,:,1))-s2loads_prior_prior(:,2,:,1),s2loads_prior_prior(:,2,:,2))

		case(2)		! student-t	
			ll=ll_normal(tdfs(0)-tdf_prior_prior(1,0),tdf_prior_prior(2,0))
		case(3)		! additive outlier
			ll=ll_normal(s2coefs(oeo,0)-tgs2-s2coefs_prior_prior(oeo,1,0),s2coefs_prior_prior(oeo,2,0))
			ll=ll+ll_normal(s2coefs(oeo,n+1)-tgs2-global_eo_mud_prior(oeo,1),global_eo_mud_prior(oeo,2))
		case(4)		! stochastic vola
			ll=ll_normal(s2coefs(oup:,0),spread(s2coefs_prior_prior(oup,2,0),1,size(s2coefs(oup:,0))))
			ll=ll+ll_normal(log(s2coefs_var(oup))-s2coefs_prior_prior(oup,1,1),s2coefs_prior_prior(oup,2,1))
		case(5)	! tvp mu
			ll=ll_normal(s2coefs(omud,0)-s2coefs_prior_prior(omud,1,0)-tgs2,s2coefs_prior_prior(omud,2,0))
			ll=ll+ll_normal(s2coefs(omud,n+1)-tgs2-global_eo_mud_prior(omud,1),global_eo_mud_prior(omud,2))
		case(6)		! tvp phi
			ll=ll_normal(s2phi_prior(:,1,2)-s2phi_prior_prior(:,1,2,1),s2phi_prior_prior(:,1,2,2))
		case(7)		! tvp loads
			ll=ll_normal(s2loads_prior(:,1,2,1)-s2loads_prior_prior(:,1,2,1),s2loads_prior_prior(:,1,2,2))
		case(8)		! factor
			ll=ll_normal(s2coefs(ou0,n+1)-tgs2-s2fac_prior(1),s2fac_prior(2))
		end select					
	end function
	
	!> Install the prior-prior setting g for component icase (modifies module/dgp state).
	!>  If go is present, the values about to be overwritten are first copied into it, so a later
	!>  setmodel(icase,go) restores them. Only the elements of go used by icase are written:
	!>  all of go for cases 1 and 6, go(1) for 2 and 8, go(1:2) for 3-5, and go(1:nld) for 7.
	!>  case(1) sets hpscale=g(1) and rescales every prior-prior variance slice from the *0 copies.
	subroutine setmodel(icase,g,go)
		integer, intent(in)	:: icase
		real, intent(in)	:: g(:)
		real, intent(inout), optional	:: go(:)
		select case(icase)
		case(1)		! HP
			if(present(go)) go=hpscale
			hpscale=g(1)
			s2phi_prior_prior(:,:,:,2)=hpscale*s2phi_prior_prior0(:,:,:,2)
			tdf_prior_prior(2,:)=hpscale*tdf_prior_prior0(2,:)
			tdfe_prior_prior(2,:)=hpscale*tdfe_prior_prior0(2,:)
			s2coefs_prior_prior(:,2,:)=hpscale*s2coefs_prior_prior0(:,2,:)
			s2loads_prior_prior(:,:,:,2)=hpscale*s2loads_prior_prior0(:,:,:,2)
		case(2)		! student-t
			if(present(go)) go(1)=tdf_prior_prior(1,0)
			tdf_prior_prior(1,0)=g(1)
		case(3)		! additive outlier
			if(present(go)) then
				go(1)=s2coefs_prior_prior(oeo,1,0)
				go(2)=global_eo_mud_prior(oeo,1)
			endif
			s2coefs_prior_prior(oeo,1,0)=g(1)
			global_eo_mud_prior(oeo,1)=g(2)
		case(4)		! stochastic vola
			if(present(go)) then
				go(1)=s2coefs_prior_prior(oup,2,0)
				go(2)=s2coefs_prior_prior(oup,1,1)
			endif
			s2coefs_prior_prior(oup,2,0)=g(1)
			s2coefs_prior_prior(oup,1,1)=g(2)
		case(5)		! tvp mu
			if(present(go)) then
				go(1)=s2coefs_prior_prior(omud,1,0)
				go(2)=global_eo_mud_prior(omud,1)
			endif
			s2coefs_prior_prior(omud,1,0)=g(1)
			global_eo_mud_prior(omud,1)=g(2)
		case(6)		! tvp phi
			if(present(go)) go=s2phi_prior_prior(:,1,2,1)
			s2phi_prior_prior(:,1,2,1)=g
		case(7)		! tvp loads
			if(present(go)) go(1:nld)=s2loads_prior_prior(:,1,2,1)
			s2loads_prior_prior(:,1,2,1)=g(1:nld)
		case(8)		! factor
			if(present(go)) then
				go(1)=s2fac_prior(1)
			endif
			s2fac_prior(1)=g(1)
		end select					
	end subroutine
		
	!> Log density ratio getll(under g1) - getll(under g0) at the current draw.
	!>  The prior-prior setting in force on entry is saved and restored before returning.
	function getlr(icase,g1,g0) result(lr)
		integer, intent(in)	:: icase
		real, intent(in)	:: g1(:),g0(:)
		real	:: lr
		real	:: go(size(g1))		! setting in force on entry
		
		call setmodel(icase,g1,go)
		lr=getll(icase)
		call setmodel(icase,g0)
		lr=lr-getll(icase)
		call setmodel(icase,go)
	end function
	
	!> Estimate the log Bayes factor for component icase by chaining bridge-sampling steps
	!> between the prior-prior settings gs(:,1),...,gs(:,nsteps).
	!>  gs(:,1) (s=0) reproduces the values in force after the initial prep_real call (hpscale=1 for case 1).
	!>  gs(:,nsteps) (s=1) is the most shifted setting: 3 prior-prior std devs for most cases,
	!>  and halved hpscale (case 1) or halved variance (case 4, g(1)).
	!>  For every step the sampler runs a burn-in, an annealing phase, an mh-scale adaptation phase
	!>  and nsim kept draws. At kept draw l it stores
	!>    LRs(1,istep,l)=ll(gs(:,istep-1))-ll(gs(:,istep)),  LRs(2,istep,l)=ll(gs(:,istep+1))-ll(gs(:,istep)),
	!>  where ll is getll(icase). BFs(k) is the getsingleBF estimate of log(Z_k/Z_{k+1}), with Z_k the
	!>  normalizing constant of the target sampled under gs(:,k). finBF=sum(BFs)=log(Z_1/Z_nsteps).
	!>  If us becomes non-finite during adaptation or sampling, it returns early with finBF=NaN.
	!>  If us becomes non-finite during the annealing phase, the program stops.
	!>  NOTE(review): the prior-prior arrays (and hpscale for case 1) are left at the last gs column
	!>  installed via setmodel, both on normal exit and on early return. Confirm that callers reset them
	!>  (e.g. via prep_real).
	subroutine setml(icase,finBF)
		integer, intent(in)	:: icase
		real, intent(out)	:: finBF
		integer, parameter	:: nsteps=3, nsim=10000
		real	:: LRs(2,nsteps,nsim),BFs(nsteps-1)
		integer	:: istep,l
		real	:: s,nan_val
		real	:: gs(nphi,nsteps)

		nan_val=ieee_value(1.0,ieee_quiet_nan)
		finBF=nan_val		! the value the caller sees if the sampler aborts via return below
		call execinML("clear tsa")
		gs=nan_val			! rows not used by icase stay NaN; the mdisp call below shows only the finite rows
		call prep_real
		s2coefs_prior_prior0=s2coefs_prior_prior
		s2loads_prior_prior0=s2loads_prior_prior
		s2loads_prior0=s2loads_prior
		s2phi_prior_prior0=s2phi_prior_prior
		tdf_prior_prior0=tdf_prior_prior; tdfe_prior_prior0=tdfe_prior_prior
		
		! Build the path of settings: s runs 0, ..., 1 across the nsteps columns
		do istep=1,nsteps
			s=(istep-1.0)/(nsteps-1)	
			select case(icase)
			case(1)		! HP
				gs(1,istep)=1-.5*s
			case(2)		! student-t
				gs(1,istep)=tdf_prior_prior(1,0)+3*s*sqrt(tdf_prior_prior(2,0))
			case(3)		! additive outlier
				gs(1,istep)=s2coefs_prior_prior(oeo,1,0)-3*s*sqrt(s2coefs_prior_prior(oeo,2,0))
				gs(2,istep)=global_eo_mud_prior(oeo,1)-3*s*sqrt(global_eo_mud_prior(oeo,2))
			case(4)		! stochastic vola
				gs(1,istep)=(1-.5*s)*s2coefs_prior_prior(oup,2,0)
				gs(2,istep)=s2coefs_prior_prior(oup,1,1)-3*s*sqrt(s2coefs_prior_prior(oup,2,1))
			case(5)		! tvp mu
				gs(1,istep)=s2coefs_prior_prior(omud,1,0)-3*s*sqrt(s2coefs_prior_prior(omud,2,0))
				gs(2,istep)=global_eo_mud_prior(omud,1)-3*s*sqrt(global_eo_mud_prior(omud,2))
			case(6)		! tvp phi
				gs(:,istep)=s2phi_prior_prior(:,1,2,1)-3*s*sqrt(s2phi_prior_prior(:,1,2,2))
			case(7)		! tvp loads
				gs(1:nld,istep)=s2loads_prior_prior(:,1,2,1)-3*s*sqrt(s2loads_prior_prior(:,1,2,2))
			case(8)		! factor
				gs(1,istep)=s2fac_prior(1)-3*s*sqrt(s2fac_prior(2))
			end select					
		enddo
		print *,"gs used for BF calc"
		call mdisp(gs(1:count(isfinite(gs(:,1))),:))
		lrs=nan_val		
		BFs=0.0
		do istep=1,nsteps
			call prep_real
			call setmodel(icase,gs(:,istep))
			call dgp_s2s			! make sure s2s and s2eos are compatible with bounds
			acc=0;accj=0; mhscales0=mhscales
			mhscales(0,:)=mhscales0(0,:)*2.0
			! burn-in: l = -min(200,n),...,0 (original note: these draws run without hp and factor;
			! presumably draw_all treats l<=0 specially -- confirm in dgp)
			do l=-min(200,n),0
				call draw_all(l)
			enddo
			! annealing: mhscales(0,:) decays from 5x towards 1x the saved scales
			do l=1,nsim/3
				mhscales(0,:)=mhscales0(0,:)*5.0**(1-l/(nsim/3.0))
				call draw_all(l)
				if(.not.all(isfinite(us(:,1)))) then
					print *,"problem in sampler; aborting 1",icase,istep,l
					stop
				endif
			enddo
			! adaptation: re-tune mh scales every 200 draws
			acc=0;accj=0
			do l=1,max(nsim/3,5)
				call draw_all_adjust(l)
				if(mod(l,200)==0) then
					call adjust_mhscales(200)
					acc=0;accj=0
				endif
				if(.not.all(isfinite(us(:,1)))) then
					print *,"problem in sampler; aborting 2",icase,istep,l
					return
				endif
			enddo
			! kept draws: record log density ratios to the neighbouring settings
			acc=0;accj=0
			do l=1,nsim
				call draw_all(l)
				if(.not.all(isfinite(us(:,1)))) then
					print *,"problem in sampler; aborting 3",icase,istep,l
					return
				endif
				if(istep>1) lrs(1,istep,l)=getlr(icase,gs(:,istep-1),gs(:,istep))
				if(istep<nsteps) lrs(2,istep,l)=getlr(icase,gs(:,istep+1),gs(:,istep))
			enddo
			if(istep>1) then
				! model 1 = setting istep-1, model 2 = setting istep
				BFs(istep-1)=getsingleBF(LRs(2,istep-1,:),LRs(1,istep,:))
				call mdisp(LRs(2,istep-1,1:nsim:10).cvr.LRs(1,istep,1:nsim:10))
			endif
		enddo
		call mdisp(BFs)
		finBF=sum(BFs)
		print *,finBF
		call printtime		
	end subroutine
				
end module
	