module globals
    ! Shared settings, grids and work arrays used by module compute, the main
    ! program, and the external residual routines HPDfunc/Leftfunc.
    use myfuncs
    use dotops
    use dispmodule
    implicit none

    ! nominal level of the confidence sets
    real, parameter     :: alpha=0.05

    ! sample length used by setY/getdens
    integer, parameter  :: T=2500
    ! theta range spanned by the grids built in prep
    real, parameter     :: thmin=.001, thmax=200
    real, allocatable   :: impgrid(:), chkgrid(:)
    integer             :: nimp, nchk

    ! quadrature rule filled by mkGQxw: column 1 holds nodes (used as points that
    ! are mapped onto an interval, presumably from [0,1]), column 2 the weights
    integer, parameter  :: nGQ=100
    real                :: GQxw(nGQ,2)

    integer, parameter  :: nsim=400000
    integer, parameter  :: nY=4
    ! Ys(:,l): statistics of simulated draw l; densimp/densalt: per-draw densities
    real                :: Ys(nY,nsim), densimp(nsim), densalt(nsim)
    real, allocatable   :: denschk(:,:)

    ! HPDCI(1,l)/HPDCI(2,l): lower/upper end of the HPD interval of draw l
    real                :: HPDCI(2,nsim)
    enum, bind(C)
        enumerator      :: ieff=1, iaug, iHPD, iahb, iah2, nCSp1
!         enumerator      :: ieff=1, nCSp1
    end enum
    integer, parameter  :: nCS=nCSp1-1,nah=2
    ! The type-spec lets literals of different lengths share one constructor
    ! (otherwise they must all have equal length); each is blank-padded to len=8.
    character(len=8)    :: CSnames(nCS)=[character(len=8) :: "eff","augment","HPD","Hansen","Andrews"]

    logical, allocatable   :: CS(:,:,:)
    ! draw index read by HPDfunc/Leftfunc; set by setHPD before each solve
    integer             :: globall
!$OMP THREADPRIVATE(globall)
    external            :: HPDfunc, Leftfunc

end module

module compute
	use globals
	implicit none
	
	contains
	
    subroutine prep
        ! Build the two theta grids and allocate the arrays sized by them.
        !   impgrid: 10 copies of thmin, then nimp0 points
        !            thmin+(thmax-thmin)*(k/nimp0)**2 for k=1..nimp0 (denser near
        !            thmin), then 10 copies of thmax.
        !   chkgrid: nstep+1 equally spaced points from thmin to thmax.
        integer, parameter  :: nstep=800, nimp0=100
        integer :: k

        impgrid=[(thmin,k=1,10), &
                 (thmin+(thmax-thmin)*(k/real(nimp0))**2,k=1,nimp0), &
                 (thmax,k=1,10)]
        nimp=size(impgrid)

        chkgrid=[(thmin+(thmax-thmin)*(k/real(nstep)),k=0,nstep)]
        nchk=size(chkgrid)

        allocate(denschk(nchk,nsim), CS(nchk,nsim,nCS))
    end subroutine

     subroutine setY(impind,Y)
        ! Simulate one AR(1)-type path with coefficient r=1-impgrid(impind)/T and
        ! return four summary statistics of the path in Y.
        !   impind: index into impgrid selecting the simulation parameter
        !   Y:      on return, Y(1)=sum J_t*(J_{t+1}-J_t), Y(2)=mean J_t**2,
        !           Y(3)=J_{T+1}, Y(4)=mean J_t, where J is the path minus its
        !           first value, scaled by 1/sqrt(T)
        integer, intent(in) :: impind
        real, intent(out)   :: Y(nY)
        integer :: i
        real    :: J(T+1),r

        r=1-impgrid(impind)/T

        ! rnnoa (IMSL) fills J with innovations, presumably standard normal
        call rnnoa(J)
        ! Scale the first innovation by 1/sqrt(1-r**2). Evaluate 1-r**2 as
        ! (1-r)*(1+r): for th near thmin, r is within a few ulps of 1 and the
        ! direct form loses most of its significant digits to cancellation.
        if(r<1) J(1)=J(1)/sqrt((1-r)*(1+r))
        do i=2,T+1
            J(i)=r*J(i-1)+J(i)
        enddo
        J=(J-J(1))/sqrt(real(T))
        Y(1)=sum(J(1:T)*(J(2:T+1)-J(1:T)))
        Y(2)=sum(J(1:T)**2)/(T)
        Y(3)=J(T+1)
        Y(4)=sum(J(1:T))/(T)
    end subroutine
    
    function getdens(th,Y) result(val)
        ! Log density of the statistics Y (see setY) under parameter th. It is used
        ! on the log scale by setdensY (logmeanexp) and exponentiated in postdens.
        real, intent(in) :: th, Y(nY)
        real    :: val, r, sS2
        val=-th*Y(1)-.5*th**2*Y(2)
        r=1-th/T
        sS2=2*(1-r)+(T-1)*(1-r)**2
        ! (1-r)*(1+r) equals 1-r**2 but avoids cancellation when r is close to 1
        ! (th near thmin), where the direct form loses most significant digits.
        val=val+.5*log((1-r)*(1+r)/sS2)+.5*(th*Y(4)+Y(3))**2*th*(1-r)/sS2
    end function
    
    function getmu(Y) result(val)
        ! Location estimate from the statistics Y: -(Y1-Y3*Y4)/(Y2-Y4**2).
        ! Used as mu in setdensY/setHPD and as thhat in setadhoc.
        real    :: Y(nY),val
        real    :: numer, denom

        numer=Y(1)-Y(3)*Y(4)
        denom=Y(2)-Y(4)**2
        val=-numer/denom
    end function
    
    function getsig(Y) result(val)
        ! Scale paired with getmu: 1/sqrt(Y2-Y4**2).
        real    :: Y(nY),val
        real    :: vhat

        vhat=Y(2)-Y(4)**2
        val=1/sqrt(vhat)
    end function

    elemental function priordens(th) result(val)
        ! Prior weight (th+100)**(-1.1). Elemental, so it also maps over arrays
        ! (bet calls it on the whole chkgrid).
        real, intent(in)    :: th
        real   :: val
        real   :: shifted

        shifted=th+100
        val=shifted**(-1.1)
    end function
        
    function postdens(th,l) result(val)
        ! Prior weight at th times exp(log density of draw l at th minus the
        ! draw's importance log density densimp(l)).
        real    :: th,val
        integer :: l
        real    :: logratio

        logratio=getdens(th,Ys(:,l))-densimp(l)
        val=priordens(th)*exp(logratio)
    end function

    subroutine setdensY
        ! Simulate the statistics Ys for all nsim draws, then for every draw compute
        !   densimp(l):   log importance density (log-mean-exp over impgrid)
        !   densalt(l):   quadrature of postdens(.,l) over th>=thmin
        !   denschk(:,l): exp(getdens(chkgrid,Y)-densimp(l)), floored at 1E-100
        integer :: l,j
        real    :: Y(nY),hs(nimp), h
        real    :: mu, sig, q0,x

        ! Serial loop: presumably kept so the random stream is consumed in a fixed
        ! order. Draw l uses impgrid(mod(l,nimp)+1), cycling through the grid.
        do l=1,nsim
            call setY(mod(l,nimp)+1,Y)
            Ys(:,l)=Y
        enddo

        call mkGQxw(GQxw)

!$omp parallel do private(Y,j,mu,sig,q0,h,hs,x)
        do l=1,nsim
            Y=Ys(:,l)
            do j=1,nimp
                hs(j)=getdens(impgrid(j),Y)
            enddo
            ! logmeanexp (myfuncs) presumably returns log(mean(exp(hs)))
            densimp(l)=logmeanexp(hs)

            ! Change of variables th=mu+sig*x, x mapped from quadrature nodes via
            ! the normal quantile on [q0,1], so th>=thmin. gausscdf/gausscdfinv/
            ! gaussdens are presumably the standard normal cdf/quantile/pdf.
            mu=getmu(Y)
            sig=getsig(Y)
            q0=gausscdf((thmin-mu)/sig)
            h=0
            do j=1,nGQ
                x=gausscdfinv(q0+GQxw(j,1)*(1-q0))
                h=h+GQxw(j,2)*postdens(mu+sig*x,l)/gaussdens(x)
            enddo
            densalt(l)=(1-q0)*sig*h
            ! Floor keeps later log(densalt/denschk) finite.
            ! NOTE(review): 1E-100 is only representable if default real is
            ! double precision (e.g. built with a real-size-64 flag) -- confirm.
            do j=1,nchk
                denschk(j,l)=max(exp(getdens(chkgrid(j),Y)-densimp(l)),1E-100)
            enddo
        enddo
    end subroutine
    
    function getcv(stats,w,alpha) result(val)
        ! Weighted upper-tail critical value.
        ! The draws are sorted by stats and weights are accumulated from the top.
        ! The result is the lowest sorted value whose upper tail (itself plus every
        ! value sorted above it) has total weight <= alpha*nsim.
        ! Callers accept when stats<val.
        !   stats: statistic per draw
        !   w:     weight per draw (importance weights in all callers)
        !   alpha: tail level
        ! Returns huge(val) when even the single largest draw's weight exceeds
        ! alpha*nsim, so that nothing is rejected.
        ! Stops with an error if the total weight is <= alpha*nsim.
        use SVRGP_INT
        ! stats: intent left unspecified because it is forwarded to svrgp
        real    :: stats(nsim)
        real, intent(in)     :: w(nsim), alpha
        real    :: val
        real, allocatable    :: sort(:)
        integer, allocatable :: inds(:)
        real    :: p
        integer :: l

        ! Work arrays are allocatable rather than automatic: getcv runs inside
        ! OpenMP parallel loops, and two nsim-sized automatic arrays per call can
        ! overflow a thread's stack.
        allocate(sort(nsim), inds(nsim))
        ! svrgp (IMSL) sorts stats ascending into sort and applies the same
        ! permutation to inds, which must start as 1..nsim.
        inds=[(l,l=1,nsim)]
        call svrgp(stats,sort,inds)
        p=0
        do l=nsim,1,-1
            if(p+w(inds(l))>alpha*nsim) then
                if(l==nsim) then
                    ! No non-empty upper tail fits the weight budget.
                    val=huge(val)
                else
                    val=sort(l+1)
                endif
                return
            endif
            p=p+w(inds(l))
        enddo
        print *,"error in getcv: weights sum up to less than alpha"
        error stop
    end function
    

    subroutine seteff
        ! "eff" set: at each chkgrid point, accept draw m when log(densalt/denschk)
        ! is below that point's weighted critical value from getcv.
        real    :: crit(nchk)
        integer :: k,m

        ! one critical value per chkgrid point; iterations are independent
!$omp parallel do
        do k=1,nchk
            crit(k)=getcv(log(densalt/denschk(k,:)),denschk(k,:),alpha)
        enddo

!$omp parallel do
        do m=1,nsim
            CS(:,m,ieff)=log(densalt(m)/denschk(:,m))<crit
        enddo
    end subroutine
    
    subroutine setHPD
        ! For every draw l, solve for an interval of posterior mass 1-alpha under
        ! postdens(.,l) with neqnf.
        ! First a one-sided interval [thmin,x] is tried (Leftfunc); otherwise a
        ! two-sided interval with equal density at both ends is solved (HPDfunc).
        ! The result goes to HPDCI(:,l), and the chkgrid points inside it are marked
        ! in CS(:,l,iHPD).
        use neqnf_int
        ! neqnf's fnorm (sum of squared residuals) below ftol counts as converged
        real, parameter :: ftol=0.01
        integer :: l,j
        real    :: leftx(1), HPDx(2)
        real    :: mu, sig

        real    :: fnorm
        ! IMSL error settings; presumably "print but do not stop" so failed solves
        ! can be retried below -- TODO confirm erset arguments.
        call erset(0,1,0)
        ! NOTE(review): parallel directive disabled (double '!'); presumably
        ! because neqnf/threadprivate globall have not been verified thread-safe.
!!$omp parallel do private(mu,sig,leftx,hpdx)
        do l=1,nsim
            ! start at the grid mode; use getmu if the mode sits at the right edge
            mu=chkgrid(maxloc(denschk(:,l),dim=1))
            if(mu==maxval(chkgrid)) mu=getmu(Ys(:,l))
            sig=getsig(Ys(:,l))
            leftx=max(mu+sig,thmin+.1)
            ! HPDfunc/Leftfunc read the draw index from globall
            globall=l
            do j=0,int(maxval(chkgrid)),1
                call neqnf(Leftfunc,leftx,xguess=leftx,fnorm=fnorm)
                if(fnorm<ftol) exit
                print *,"fatal error caught, trying new inital value"
                leftx=thmin+j
            enddo
            ! keep the one-sided interval when postdens at its right end is below
            ! postdens at thmin
            if(postdens(leftx(1),l)<postdens(thmin,l) .and. fnorm<ftol) then
                HPDCI(:,l)=[thmin,leftx]
            else
                hpdx=[max(thmin,mu-sig),mu+sig]
                do j=0,20
                    call neqnf(HPDfunc,hpdx,xguess=hpdx,fnorm=fnorm)
                    if(fnorm<ftol) exit
                    print *,"fatal error caught, trying new inital value"
                    hpdx=[max(thmin,mu-j),mu+j]
                enddo
                ! .not.(fnorm<ftol) also reports a NaN residual norm, which a
                ! plain fnorm>ftol test would let through silently.
                if(.not. (fnorm<ftol)) then
                    print *,"HPD determination failed"
                    call disp(denschk(:,l))
                endif
                HPDCI(:,l)=hpdx
            endif
            CS(:,l,iHPD)=chkgrid>=HPDCI(1,l) .and. chkgrid<=HPDCI(2,l)
        enddo
    end subroutine
    
    subroutine setaug
        ! "augment" set: a chkgrid point inside draw m's HPD interval (HPDCI) is
        ! always accepted; outside it, the point is accepted when
        ! log(densalt/denschk) is below a weighted critical value from getcv.
        ! sentinel statistic assigned to points inside the HPD interval
        real, parameter     :: inHPD=-100000.0
        real    :: crit(nchk)
        real, allocatable   :: tstat(:,:)
        integer :: k,m

        allocate(tstat(nchk,nsim))
!$omp parallel do
        do k=1,nchk
            ! outside the interval: the log ratio; inside: the sentinel
            tstat(k,:)=merge(log(densalt/denschk(k,:)),inHPD,HPDCI(1,:)>chkgrid(k).or.HPDCI(2,:)<chkgrid(k))
            crit(k)=getcv(tstat(k,:),denschk(k,:),alpha)
        enddo
!$omp parallel do
        do m=1,nsim
            CS(:,m,iaug)=(tstat(:,m)==inHPD) .or. (tstat(:,m)<crit)
        enddo
    end subroutine

    subroutine setadhoc
        ! Two equal-tailed sets stored at CS indices iahb and iahb+1:
        !   statistic 1 = (thhat-th)/sighat
        !   statistic 2 = thhat
        ! Each statistic gets a weighted lower and upper critical value (alpha/2 per
        ! tail) at every chkgrid point. A point is accepted when the statistic lies
        ! strictly between them.
        real    :: cv(2,nchk,nah)
        real, allocatable   :: thhat(:), sighat(:)
        real, allocatable   :: stats(:,:,:)
        integer :: l,i,j

        ! getmu/getsig depend only on the draw: evaluate them once per draw rather
        ! than once per (chkgrid point, draw) pair.
        allocate(thhat(nsim), sighat(nsim))
!$omp parallel do
        do l=1,nsim
            thhat(l)=getmu(Ys(:,l))
            sighat(l)=getsig(Ys(:,l))
        enddo

        allocate(stats(nah,nchk,nsim))
        ! Inside this region shared arrays are only read, or written at index i,
        ! so no scratch buffer is shared between threads.
!$omp parallel do private(l,j)
        do i=1,nchk
            do l=1,nsim
                stats(1,i,l)=(thhat(l)-chkgrid(i))/sighat(l)
                stats(2,i,l)=thhat(l)
            enddo
            do j=1,nah
                cv(1,i,j)=-getcv(-stats(j,i,:),denschk(i,:),alpha/2)
                cv(2,i,j)=getcv(stats(j,i,:),denschk(i,:),alpha/2)
            enddo
        enddo
        do j=1,nah
            call disp(cv(:,:,j))
        enddo
        do l=1,nsim
            do j=1,nah
                CS(:,l,iahb+j-1)=stats(j,:,l)>cv(1,:,j) .and. stats(j,:,l)<cv(2,:,j)
            enddo
        enddo
    end subroutine
    
    subroutine evalCS
        ! For each set i, print two importance-weighted averages over draws at every
        ! chkgrid point:
        !   row 1: weight of draws whose set excludes the point
        !   row 2: weighted set length (accepted grid points times grid spacing)
        integer :: i,l
        real    :: out(2,nchk,nCS)
        real    :: dth

        ! chkgrid is equally spaced (built in prep), so one spacing serves all points
        dth=chkgrid(2)-chkgrid(1)
        out=0
        do i=1,nCS
            do l=1,nsim
                out(1,:,i)=out(1,:,i)+merge(0.0,denschk(:,l),CS(:,l,i))/nsim
                out(2,:,i)=out(2,:,i)+denschk(:,l)*count(CS(:,l,i))*dth/nsim
            enddo

            print *,CSnames(i)
            call disp(chkgrid.cud.out(:,:,i))
        enddo
     end subroutine
    
    subroutine bet
        ! For each set CSnames(iahb:nCS) and each alphaprime in alphap, iterate a
        ! per-grid-point weight kappa (added to the prior) for 5000 steps.
        ! Each step updates kappa from ew (presumably the expected winnings of the
        ! bet) and betprob (weighted probability that a bet is placed).
        ! Afterwards it prints 100*ew, 100*betprob and 100*(ap+(1-ap)*ew/betprob).
        real    :: kappa(nchk), pi(nchk),post(nchk)
        real    :: ew(nchk),betprob(nchk)
        real    :: oldew(nchk), kappafac(nchk),ap
        integer :: l,j,lc,ia
        real, parameter :: alphap(4)=[0.05,0.06,.1,.5]
        real    :: out(size(alphap),nchk,nCS,3)

        out=0

        do j=iahb,nCS
            print *,"betting against ", CSnames(j)
            do ia=1,size(alphap)
                ap=alphap(ia)
                kappa=0.0001
                kappafac=1
                oldew=0
                do lc=1,5000
                    pi=priordens(chkgrid)+kappa
                    ew=0; betprob=0
!$omp parallel do private(post) reduction(+:ew,betprob)
                    do l=1,nsim
                        post=pi*denschk(:,l)
                        ! bet when the set is non-empty, excludes the last grid point,
                        ! and holds less than 1-ap of the pi-weighted mass
                        if(sum(post,mask=CS(:,l,j))<(1-ap)*sum(post) .and. any(CS(:,l,j)) &
                           .and. (.not. CS(nchk,l,j))) then
                            ew=ew+denschk(:,l)*(1-(1+ap/(1-ap))*boole(CS(:,l,j)))
                            betprob=betprob+denschk(:,l)
                        endif
                    enddo
                    ew=ew/nsim
                    betprob=betprob/nsim

                    if(mod(lc-1,500)==0) then
                        print *,lc
                        call mdisp(chkgrid.cvr.log(kappa).cud.(100*ew).cud.kappafac)
                    endif
                    ! Where no bet was placed, betprob==0 and ew==0. The ratio would
                    ! be 0/0, and the NaN would stay in kappa for all later steps, so
                    ! kappa is left unchanged there.
                    where(betprob>0) kappa=kappa*exp(-kappafac*ew/betprob)

                    ! grow the step while ew keeps its sign, halve it when it flips
                    where(ew*oldew.ge.0.0)
                        kappafac=kappafac*1.03
                    elsewhere
                        kappafac=kappafac*0.5
                    endwhere
                    call setbounds(kappafac,1E-7,1.0)
                    call setbounds(kappa,exp(-15.0),exp(10.0))
                    oldew=ew
                enddo
                out(ia,:,j,1)=100*ew
                out(ia,:,j,2)=100*betprob
                out(ia,:,j,3)=100*(ap+(1-ap)*ew/betprob)
                print *,"results for ", CSnames(j), " and alphaprime equal to ",alphap(ia)
                call mdisp(chkgrid.cvr.log(kappa).cud.out(ia,:,j,1).cud.out(ia,:,j,2).cud.out(ia,:,j,3))
                call mdisp(chkgrid .cud. out(:,:,j,1) .cud. out(:,:,j,2) .cud. out(:,:,j,3))
            enddo
        enddo

    end subroutine
end module

program mfort
    ! Driver: configure output and the random number generator, then run
    ! grids -> simulated densities -> confidence sets -> evaluation -> betting.
    use globals
    use compute
    implicit none

    call disp_set(advance="double",orient="row")
    call inittime
    ! IMSL generator selection and fixed seed (presumably for reproducibility)
    call rnopt(6)
    call rnset(14)

    ! grids and allocations
    call prep
    ! simulated statistics and per-draw densities
    call setdensY
    ! confidence sets; setaug reads HPDCI written by setHPD
    call seteff
    call setHPD
    call setaug
    call setadhoc
    call evalCS

    call bet

    stop

end program mfort

subroutine HPDfunc(X,F,n)
    ! Residuals solved by neqnf in setHPD for the two-sided interval [x(1),x(2)]
    ! of draw l=globall:
    !   F(1) = log(postdens(x(1))/postdens(x(2)))            (equal density at both ends)
    !   F(2) = (quadrature of postdens over [x(1),x(2)])/densalt - (1-alpha)
    use compute
    implicit none
    ! n must be declared before it is used as an array bound under implicit none
    integer, intent(in) :: n
    real, intent(in)    :: X(n)
    real, intent(out)   :: F(n)

    real        :: xc
    integer     :: j

    ! Same value as log(postdens(x1)/postdens(x2)), written as a difference of logs:
    ! densimp(globall) cancels and no exp() is taken, so F(1) stays finite where
    ! postdens would underflow to zero at both ends.
    F(1)=log(priordens(x(1))/priordens(x(2))) &
        +getdens(x(1),Ys(:,globall))-getdens(x(2),Ys(:,globall))
    F(2)=0
    do j=1,nGQ
        xc=x(1)+GQxw(j,1)*(x(2)-x(1))
        F(2)=F(2)+GQxw(j,2)*postdens(xc,globall)
    enddo
    F(2)=F(2)*(x(2)-x(1))/densalt(globall)-(1-alpha)

end subroutine
    
subroutine Leftfunc(X,F,n)
    ! Residual solved by neqnf in setHPD for the one-sided interval [thmin,x(1)]
    ! of draw l=globall:
    !   F(1) = (quadrature of postdens over [thmin,x(1)])/densalt - (1-alpha)
    use compute
    implicit none
    ! n must be declared before it is used as an array bound under implicit none
    integer, intent(in) :: n
    real, intent(in)    :: X(n)
    real, intent(out)   :: F(n)

    real                :: xc
    integer             :: j

    F(1)=0
    do j=1,nGQ
        xc=thmin+GQxw(j,1)*(x(1)-thmin)
        F(1)=F(1)+GQxw(j,2)*postdens(xc,globall)
    enddo
    F(1)=F(1)*(x(1)-thmin)/densalt(globall)-(1-alpha)

end subroutine
    