module draw_loads_mod
	
	use basicmodule
	implicit none
	
	contains
	
	!> Add the loading contributions and the residual eos(:,j) onto series j
	!> of us.
	!>
	!> For every factor l and period t = 1..capT, us(t,j) is incremented by
	!> sum(loads(:,t,l,j) * us(t:t-nld+1:-1, n+l)), i.e. the loadings applied
	!> to the current and nld-1 lagged values of column n+l. Afterwards
	!> eos(:,j) is added to us(1:capT,j). sub_aloads subtracts the same terms.
	!>
	!> @param j  column of us (series index) that is updated in place
	subroutine add_aloads(j)
		integer, intent(in)	:: j
		integer	:: l,t
		do l=1,nfac
			do t=1,capT
				us(t,j)=us(t,j)+sum(loads(:,t,l,j)*us(t:t-nld+1:-1,n+l))
			enddo
		enddo
		us(1:capT,j)=us(1:capT,j)+eos(:,j)
	end subroutine add_aloads

	!> Remove the loading contributions and the residual eos(:,j) from
	!> series j of us.
	!>
	!> For every factor l and period t = 1..capT, us(t,j) is decremented by
	!> sum(loads(:,t,l,j) * us(t:t-nld+1:-1, n+l)); afterwards eos(:,j) is
	!> subtracted from us(1:capT,j). These are the same terms that
	!> add_aloads adds.
	!>
	!> @param j  column of us (series index) that is updated in place
	subroutine sub_aloads(j)
		integer, intent(in)	:: j
		integer	:: l,t
		do l=1,nfac
			do t=1,capT
				us(t,j)=us(t,j)-sum(loads(:,t,l,j)*us(t:t-nld+1:-1,n+l))
			enddo
		enddo
		us(1:capT,j)=us(1:capT,j)-eos(:,j)
	end subroutine sub_aloads
	
	!> Draw the time-varying loadings loads(:,1:capT,1,j) of series j with a
	!> state-space simulation smoother (forward Kalman filter, backward
	!> smoothing recursion).
	!>
	!> State vector s (length nstate = nphi + nld):
	!>   s(1:nphi)   current value and nphi-1 lags of the AR component,
	!>               advanced by transstate with coefficients phis(:,t,j);
	!>               innovation variance s2s(t,j), initial variance s2s(1,j)
	!>   s(nphi+1:)  the nld loadings, a random walk with innovation variance
	!>               min(s2loads(:,2,1,j),maxs2loadd), initial variance
	!>               s2loads(:,1,1,j) and initial mean s2loads_prior(:,1,1,:)
	!> Observation at period t:
	!>   us(t,j) = s(1) + sum(s(nphi+1:)*us(t:t-nld+1:-1,n+1:)) + eps,
	!>   var(eps) = s2eos(t,j)
	!>
	!> Two filters share one covariance recursion: column 1 of shat/as/r
	!> filters the actual data us(:,j); column 2 filters artificial data
	!> generated from a simulated state path s0 (stored in s0s). The draw is
	!> smoothed(column 1) - smoothed(column 2) + s0s, which appears to follow
	!> the Durbin-Koopman simulation-smoother construction.
	!>
	!> Side effects: overwrites loads(:,1:capT,1,j) and eos(1:capT,j); on
	!> return us(t,j) holds the drawn AR component s(1) for t = 1..capT.
	!> Presumably us(1:capT,j) holds the full observation on entry (e.g.
	!> after add_aloads) - TODO confirm with callers.
	!> NOTE(review): nfac is redeclared locally as 1 (shadowing the module
	!> value), and h(nphi+1:)=[us(...,n+1:)] only conforms if us has exactly
	!> n+1 columns - this routine appears to assume a single factor.
	!> NOTE(review): rnnoa is assumed to fill its argument with N(0,1) draws.
	!>
	!> @param j  series index (column of us) to update
	subroutine draw_aloads(j)
		integer, parameter	:: nfac=1,nstate=nphi+nfac*nld
		integer	:: j
		real	:: V(nstate,nstate),varstate(nfac*nld),uvarstate(nfac*nld)
		real	:: h(nstate),Vh(nstate),ev
		! s0s(:,0) is declared but never written.
		real	:: shat(nstate,2),s0(nstate),s0s(nstate,0:capT)
		real	:: shats(nstate,2,capT), Vs(nstate,nstate,capT),z(nstate),as(2,capT),bs(nstate,capT),f(nstate,nstate)
		real	:: r(nstate,2),tmp,yeps(capT)
		integer	:: t,i,l
		! Predicted state covariance for t=1: s2s(1,j) on the AR element,
		! s2loads(:,1,l,j) on the loadings, zero elsewhere.
		V=0
		V(1,1)=s2s(1,j)
		do l=1,nfac
			! Per-period loading innovation variance (capped) and initial variance.
			varstate((l-1)*nld+1:l*nld)=min(s2loads(:,2,l,j),maxs2loadd)
			uvarstate((l-1)*nld+1:l*nld)=s2loads(:,1,l,j)
		enddo
		do i=nphi+1,nstate
			V(i,i)=uvarstate(i-nphi)
		enddo
		! Simulated initial state for the artificial-data filter: zero mean,
		! same diagonal initial variances as V; AR lags 2..nphi set to zero.
		call rnnoa(z)
		s0(1)=sqrt(V(1,1))*z(1)
		s0(2:nphi+1-1)=0
		s0(nphi+1:nstate)=sqrt(uvarstate)*z(nphi+1:nstate)
		! Predicted mean for the real-data filter (column 1): pre-sample
		! values us(0:-nphi+1,j) advanced one period, loadings at the prior
		! mean. Column 2 (artificial data) starts at zero.
		shat=0
		shat(1:nphi,1)=us(0:-nphi+1:-1,j)
		shat(nphi+1:,1)=shat(nphi+1:,1)+[s2loads_prior(:,1,1,:)]
		! t=0 so that transstate uses phis(:,1,j).
		t=0
		call transstate(shat(:,1))
		! Observation vector: h(1)=1 selects the AR component, h(2:nphi) stay
		! zero, h(nphi+1:) is refreshed each period with lagged factor values.
		h=0; h(1)=1
		! Observation-noise draws for the artificial data.
		call rnnoa(yeps)
		! ---- Forward pass: filter both data sets ----
		do t=1,capT
			! Keep the simulated state and the predicted moments for the
			! backward pass.
			s0s(:,t)=s0
			shats(:,:,t)=shat
			Vs(:,:,t)=V	
			h(nphi+1:)=[us(t:t-nld+1:-1,n+1:)]
			! Vh = V*h, skipping the zero entries h(2:nphi).
!			Vh=matmul(V,h)
			Vh=V(:,1)+matmul(V(:,nphi+1:),h(nphi+1:))
			! ev = 1/(h'Vh + s2eos(t,j)), the inverse innovation variance.
!			ev=sum(Vh*h)
			ev=Vh(1)+sum(Vh(nphi+1:)*h(nphi+1:))+s2eos(t,j)
			ev=1.0/ev
			bs(:,t)=ev*Vh
			! Innovations scaled by ev: column 1 uses us(t,j); column 2 uses the
			! artificial observation h's0 + sqrt(s2eos(t,j))*yeps(t).
!			as(:,t)=[us(t,j)-sum(h*shat(:,1)),sum(h*(s0-shat(:,2)))]*ev
			as(:,t)=[us(t,j)-shat(1,1)-sum(h(nphi+1:)*shat(nphi+1:,1)),s0(1)-shat(1,2)+sum(h(nphi+1:)*(s0(nphi+1:)-shat(nphi+1:,2)))+sqrt(s2eos(t,j))*yeps(t)]*ev

			! Measurement update of both means and of the shared covariance
			! (V <- V - ev*Vh*Vh').
			shat(:,1)=shat(:,1)+as(1,t)*Vh
			shat(:,2)=shat(:,2)+as(2,t)*Vh
			do i=1,nstate
				V(:,i)=V(:,i)-bs(i,t)*Vh
			enddo
			if(t==capT) exit
			! Advance the simulated state to t+1 and add its innovations.
			call transstate(s0)			
			call rnnoa(z(1:1))
			s0(1)=s0(1)+sqrt(s2s(t+1,j))*z(1)
			call rnnoa(z(nphi+1:))
			s0(nphi+1:)=s0(nphi+1:)+sqrt(varstate)*z(nphi+1:)
			! Time update of both filtered means.
			call transstate(shat(:,1))
			call transstate(shat(:,2))
			! Time update of the covariance: F = T*V, then V = F*T', where T is
			! the companion matrix of phis(:,t+1,j) on the AR block and the
			! identity on the loadings block.
			do i=1,nstate
				F(1,i)=sum(phis(:,t+1,j)*V(1:nphi,i))	
			enddo
			F(2:nphi,:)=V(1:nphi-1,:)
			F(nphi+1:,:)=V(nphi+1:,:)
			
			do i=1,nstate
				V(i,1)=sum(phis(:,t+1,j)*F(i,1:nphi))	
			enddo
			V(:,2:nphi)=F(:,1:nphi-1)
			V(:,nphi+1:)=F(:,nphi+1:)
			! Add the state innovation variances.
			V(1,1)=V(1,1)+s2s(t+1,j)
			do i=nphi+1,nstate
				V(i,i)=V(i,i)+varstate(i-nphi)
			enddo
		enddo
		! ---- Backward pass: smoothing weights r, then the draw ----
		r=0
		do t=capT,1,-1
			h(nphi+1:)=[us(t:t-nld+1:-1,n+1:)]
			! r <- h*as(:,t) + (I - h*bs(:,t)')*r  (r already premultiplied by T').
			r(:,1)=r(:,1)+(as(1,t)-sum(bs(:,t)*r(:,1)))*h
			r(:,2)=r(:,2)+(as(2,t)-sum(bs(:,t)*r(:,2)))*h
			! Smoothed means of both filters at period t.
			shat=shats(:,:,t)+matmul(Vs(:,:,t),r)
			do l=1,nfac
				! Draw = smoothed(data) - smoothed(artificial) + simulated state.
				loads(:,t,l,j)=shat((l-1)*nld+nphi+1:l*nld+nphi+1-1,1)-shat((l-1)*nld+nphi+1:l*nld+nphi+1-1,2)+s0s((l-1)*nld+nphi+1:l*nld+nphi+1-1,t)
				! Remove the drawn loading contribution from us(t,j).
				us(t,j)=us(t,j)-sum(loads(:,t,l,j)*us(t:t-nld+1:-1,n+l))
			enddo
			! eos is what remains after removing the drawn AR component; after
			! the next line us(t,j) equals that AR component draw.
			eos(t,j)=us(t,j)-(shat(1,1)-shat(1,2)+s0s(1,t))
			us(t,j)=us(t,j)-eos(t,j)
			! r <- T' r with T the companion matrix of phis(:,t,j); the loading
			! rows of r are left unchanged.
			do i=1,2
				tmp=r(1,i)
				r(1:nphi-1,i)=r(2:nphi,i)
				r(nphi,i)=0
				r(1:nphi,i)=r(1:nphi,i)+tmp*phis(:,t,j)
			enddo			
		enddo
	contains
		!> Advance the AR block of a state vector one period using the host's
		!> current t and j: s(1) <- sum(phis(:,t+1,j)*s(1:nphi)),
		!> s(2:nphi) <- old s(1:nphi-1). Entries nphi+1: are untouched.
		subroutine transstate(s)
			real	:: s(nstate),tmp
			tmp=sum(phis(:,t+1,j)*s(1:nphi))
			s(2:nphi)=s(1:nphi-1)
			s(1)=tmp
		end subroutine
	end subroutine
	

	!> Log-posterior kernel for the loading variances of series j.
	!>
	!> val = Gaussian prior on log(s2load(:,2,l)) with mean s2load_prior(:,1,2,l)
	!>       and variance s2load_prior(:,2,2,l)
	!>     + getll_startup(j,s2s(0,j))  (defined elsewhere)
	!>     + Gaussian log-likelihood of us(1:capT,j) from a Kalman filter.
	!> Additive constants (2*pi terms, prior normalisers) are omitted.
	!>
	!> Filter state: s(1:nphi) is the AR component and its lags (transition
	!> coefficients phis(:,t,j), innovation variance s2s(t,j)); s(nphi+1:) are
	!> the nld loadings, a random walk with innovation variance
	!> min(s2load(:,2,l),maxs2loadd), initial variance s2load(:,1,l) and
	!> initial mean s2load_prior(:,1,1,:). Observation noise variance is
	!> s2eos(t,j).
	!> NOTE(review): nfac is fixed to 1 locally, and h(nphi+1:)=[us(...,n+1:)]
	!> only conforms if us has exactly n+1 columns.
	!>
	!> @param j             series index (column of us)
	!> @param s2load        (:,1,l) initial / (:,2,l) per-period loading variances
	!> @param s2load_prior  prior hyperparameters for the loadings
	!> @return val          log-density kernel (up to an additive constant)
	function get_aloads_ll(j,s2load,s2load_prior) result(val)
		! nfac=1 here shadows the module-level nfac.
		integer, parameter	:: nfac=1,nstate=nphi+nfac*nld
		integer, intent(in)	:: j
		real, intent(in)	:: s2load(nld,2,nfac),s2load_prior(nld,2,2,nfac)
		real	:: val
		real	:: V(nstate,nstate),varstate(nfac*nld),uvarstate(nfac*nld)
		real	:: shat(nstate),h(nstate),Vh(nstate),ev,e,F(nstate,nstate)
		integer	:: t,i,l

		! Prior on the log per-period loading variances.
		val=0
		do l=1,nfac
			val=val-.5*sum((log(s2load(:,2,l))-s2load_prior(:,1,2,l))**2/s2load_prior(:,2,2,l))
		enddo
		val=val+getll_startup(j,s2s(0,j))

		do l=1,nfac
			uvarstate((l-1)*nld+1:l*nld)=s2load(:,1,l)
			varstate((l-1)*nld+1:l*nld)=min(s2load(:,2,l),maxs2loadd)
		enddo
		! Predicted covariance at t=1.
		V=0
		V(1,1)=s2s(1,j)
		do i=nphi+1,nstate
			V(i,i)=uvarstate(i-nphi)
		enddo
		! Predicted mean at t=1: pre-sample AR values advanced one period
		! (t=0 makes transstate use phis(:,1,j)), loadings at the prior mean.
		shat=0
		shat(1:nphi)=us(0:-nphi+1:-1,j)
		t=0
		call transstate(shat)
		! Use the s2load_prior argument, not the module array s2loads_prior:
		! draw_as2load_prior evaluates this function at a proposed prior, and
		! the proposed initial loading mean must enter the likelihood.
		shat(nphi+1:)=shat(nphi+1:)+[s2load_prior(:,1,1,:)]
		! Observation vector: h(1)=1 selects the AR component, h(2:nphi)=0,
		! h(nphi+1:) holds the lagged factor values (refreshed each period).
		h=0
		h(1)=1
		do t=1,capT
			h(nphi+1:)=[us(t:t-nld+1:-1,n+1:)]
			! Vh = V*h, skipping the zero entries h(2:nphi).
			Vh=V(:,1)+matmul(V(:,nphi+1:),h(nphi+1:))
			! ev = 1/F with F = h'Vh + s2eos(t,j) the innovation variance.
			ev=Vh(1)+sum(Vh(nphi+1:)*h(nphi+1:))+s2eos(t,j)
			ev=1.0/ev
			e=us(t,j)-shat(1)-sum(h(nphi+1:)*shat(nphi+1:))
			! -e**2/(2F) - log(F)/2
			val=val-.5*e**2*ev+.5*log(ev)
			if(t==capT) exit

			! Measurement update.
			shat=shat+e*ev*Vh
			do i=1,nstate
				V(:,i)=V(:,i)-ev*Vh(i)*Vh
			enddo
			! Time update: mean, then V <- T*V*T' + Q.
			call transstate(shat)
			do i=1,nstate
				F(1,i)=sum(phis(:,t+1,j)*V(1:nphi,i))
			enddo
			F(2:nphi,:)=V(1:nphi-1,:)
			F(nphi+1:,:)=V(nphi+1:,:)

			do i=1,nstate
				V(i,1)=sum(phis(:,t+1,j)*F(i,1:nphi))
			enddo
			V(:,2:nphi)=F(:,1:nphi-1)
			V(:,nphi+1:)=F(:,nphi+1:)
			V(1,1)=V(1,1)+s2s(t+1,j)
			do i=nphi+1,nstate
				V(i,i)=V(i,i)+varstate(i-nphi)
			enddo
		enddo

	contains
		!> Advance the AR block of a state vector one period using the host's
		!> current t and j: s(1) <- sum(phis(:,t+1,j)*s(1:nphi)),
		!> s(2:nphi) <- old s(1:nphi-1). Entries nphi+1: are untouched.
		subroutine transstate(s)
			real	:: s(nstate),tmp
			tmp=sum(phis(:,t+1,j)*s(1:nphi))
			s(2:nphi)=s(1:nphi-1)
			s(1)=tmp
		end subroutine transstate
	end function get_aloads_ll

	!> Metropolis-Hastings update of the per-period loading variances
	!> s2loads(:,2,l,j) of series j.
	!>
	!> Proposal: multiplicative random walk, s2 * exp(c*z) with
	!> c = mhscales(0,oas2load)*sqrt(s2loads_prior(:,2,2,l))/sqrt(capT) and z
	!> from rnnoa (presumably standard normal). The initial variances
	!> s2loads(:,1,l,j) are kept. The log acceptance ratio comes from
	!> get_aloads_ll; on acceptance s2loads(:,:,:,j) is replaced and
	!> accj(0,oas2load,j) is incremented.
	!>
	!> @param j  series index
	subroutine draw_as2load(j)
		integer, intent(in)	:: j
		integer	:: l
		real	:: new_s(nld,2,nfac)
		real	:: ll,z(nld),u(1)
		logical	:: accept
		do l=1,nfac
			new_s(:,1,l)=s2loads(:,1,l,j)
			call rnnoa(z)
			new_s(:,2,l)=s2loads(:,2,l,j)*exp(mhscales(0,oas2load)*sqrt(s2loads_prior(:,2,2,l))*z/sqrt(real(capT)))
		enddo
		ll=get_aloads_ll(j,new_s,s2loads_prior)-get_aloads_ll(j,s2loads(:,:,:,j),s2loads_prior)
		! Always draw u so the random-number stream is unchanged. For ll>=0,
		! exp(ll)>=1 exceeds any uniform draw (assumed < 1), so the move is
		! accepted without evaluating exp(ll), which could overflow.
		! A NaN ll still leads to rejection.
		call rnun(u)
		if(ll>=0.0) then
			accept=.true.
		else
			accept=u(1)<exp(ll)
		endif
		if(accept) then
			s2loads(:,:,:,j)=new_s
			accj(0,oas2load,j)=accj(0,oas2load,j)+1
		endif
	end subroutine draw_as2load
	
			
	!> Metropolis-Hastings update of the loading hyperparameters
	!> s2loads_prior, jointly with a deterministic move of every series'
	!> per-period loading variances.
	!>
	!> Proposal for each factor l, selected by imh:
	!>   imh==0    all entries perturbed; z is scaled by mhscales(0,oas2load_prior)
	!>   imh==1,2  only new_prior(:,1,imh,l) is perturbed
	!>   imh==3,4  only new_prior(:,2,imh-2,l) is perturbed
	!> Entries (:,1,:,l) take additive random-walk steps, entries (:,2,:,l)
	!> multiplicative ones; for l>1 element 1 of (:,2,1,l) is never moved.
	!> For every series j, log(s2loads(:,2,l,j)) is shifted so that its value
	!> standardised by the prior mean (:,1,2,l) and variance (:,2,2,l) is
	!> unchanged, and s2loads(:,1,l,j) is set to new_prior(:,2,1,l).
	!> On acceptance s2loads and s2loads_prior are replaced and
	!> acc(imh,oas2load_prior) is incremented.
	subroutine draw_as2load_prior
		integer	:: l,j
		real	:: new_s(nld,2,nfac,n)
		real	:: ll,z(nld,2),u(1)
		real	:: new_prior(nld,2,2,nfac),e(nld,nfac)
		logical	:: accept
		do l=1,nfac
			! Steps for the (:,1,:,l) entries.
			if(imh==0) then
				call rnnoa(z)
				z=z*mhscales(0,oas2load_prior)
			else
				z=0
				if(imh<3) then
					call rnnoa(z(:,imh:imh))
				endif
			endif
			new_prior(:,1,1,l)=s2loads_prior(:,1,1,l)+mhscales(1,oas2load_prior)*sqrt(s2loads_prior_prior(:,1,1,2)/(capT*n))*z(:,1)
			new_prior(:,1,2,l)=s2loads_prior(:,1,2,l)+mhscales(2,oas2load_prior)*sqrt(s2loads_prior_prior(:,1,2,2)/(capT*n))*z(:,2)
			! Log-scale steps for the (:,2,:,l) entries.
			if(imh==0) then
				call rnnoa(z)
				z=z*mhscales(0,oas2load_prior)
			else
				z=0
				if(imh>2) then
					call rnnoa(z(:,imh-2:imh-2))
				endif
			endif
			! For l>1, element 1 of (:,2,1,l) is held fixed.
			if(l>1) z(1,1)=0
			new_prior(:,2,1,l)=s2loads_prior(:,2,1,l)*exp(mhscales(3,oas2load_prior)*sqrt(s2loads_prior_prior(:,2,1,2)/(capT*n))*z(:,1))
			new_prior(:,2,2,l)=s2loads_prior(:,2,2,l)*exp(mhscales(4,oas2load_prior)*sqrt(s2loads_prior_prior(:,2,2,2)/(capT*n))*z(:,2))
		enddo

		ll=get_s2load_prior_ll(new_prior)-get_s2load_prior_ll(s2loads_prior)
		! Map every series' log variances to the proposed prior while keeping
		! (x-mean)/sqrt(var) fixed, and accumulate the likelihood change.
!$omp parallel do private(l,e) reduction(+:ll)		
		do j=1,n
			do l=1,nfac
				e(:,l)=log(s2loads(:,2,l,j))-s2loads_prior(:,1,2,l)
				e(:,l)=e(:,l)*sqrt(new_prior(:,2,2,l)/s2loads_prior(:,2,2,l))
				e(:,l)=e(:,l)+new_prior(:,1,2,l)
			enddo
			new_s(:,2,:,j)=exp(e)
			do l=1,nfac
				new_s(:,1,l,j)=new_prior(:,2,1,l)
			enddo
			ll=ll+get_aloads_ll(j,new_s(:,:,:,j),new_prior)-get_aloads_ll(j,s2loads(:,:,:,j),s2loads_prior)
		enddo
		! Always draw u so the random-number stream is unchanged. For ll>=0,
		! exp(ll)>=1 exceeds any uniform draw (assumed < 1), so the move is
		! accepted without evaluating exp(ll), which could overflow because
		! ll sums over all n series. A NaN ll still leads to rejection.
		call rnun(u)
		if(ll>=0.0) then
			accept=.true.
		else
			accept=u(1)<exp(ll)
		endif
		if(accept) then
			s2loads=new_s
			s2loads_prior=new_prior
			acc(imh,oas2load_prior)=acc(imh,oas2load_prior)+1
		endif
	end subroutine draw_as2load_prior

	!> Log-density kernel of the hyperprior on the loading-prior parameters,
	!> evaluated at prior (additive constants omitted).
	!>
	!> Each term is -0.5*(x-m)**2/v with m = s2loads_prior_prior(...,1) and
	!> v = s2loads_prior_prior(...,2); the same hyperparameters are shared by
	!> every factor l:
	!>   x = prior(2:,1,1,l)           (element 1 has no prior term)
	!>   x = prior(:,1,2,l)
	!>   x = log(prior(:,2,1,l))       all elements for l==1, 2:nld for l>1
	!>   x = log(prior(:,2,2,l))
	!>
	!> @param prior  candidate value of s2loads_prior
	!> @return val   log-density kernel
	function get_s2load_prior_ll(prior) result(val)
		real, intent(in)	:: prior(nld,2,2,nfac)
		real	:: val
		integer	:: l
		val=0
		do l=1,nfac
			val=val-.5*sum((prior(2:,1,1,l)-s2loads_prior_prior(2:,1,1,1))**2/s2loads_prior_prior(2:,1,1,2))
			val=val-.5*sum((prior(:,1,2,l)-s2loads_prior_prior(:,1,2,1))**2/s2loads_prior_prior(:,1,2,2))
			if(l==1) then
				val=val-.5*sum((log(prior(:,2,1,l))-s2loads_prior_prior(:,2,1,1))**2/s2loads_prior_prior(:,2,1,2))
			else
				! Element 1 of (:,2,1,l) is fixed for l>1 and has no prior term.
				val=val-.5*sum((log(prior(2:,2,1,l))-s2loads_prior_prior(2:,2,1,1))**2/s2loads_prior_prior(2:,2,1,2))
			endif
			val=val-.5*sum((log(prior(:,2,2,l))-s2loads_prior_prior(:,2,2,1))**2/s2loads_prior_prior(:,2,2,2))
		enddo
	end function get_s2load_prior_ll
	
end module
