%%% AUTOMATED ESTIMATION OF INFLATION MODELS %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
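% Recursively estimate the Multivariate Core Trend (MCT) model on the most
% recent vintage of monthly disaggregated PCE data. For each forecast origin
% from first_fcst through the last month in the data, the mean sector
% forecasts, the realized values and the share-weighted aggregate forecast
% are saved.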
%
% Version: 2022 Mar 03 - Matlab R2020a
% MODIFIED BY MW Watson , 12/16/2024, 6/13/2025
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% ---- Workspace reset and reproducibility --------------------------------
clear
close all
clc

% Seed the global random number generator; the same seed value is also
% handed to estimate_MCT_watson for every forecast origin.
rng_seed = 2022;
rng(rng_seed);

% Restart flag: 1 = resume from the intermediate results file saved by a
% previous run, 0 = start from scratch.
i_restart = 1;

% ---- Input / output locations (relative to the working directory) -------
data_path = '../../data/';   % folder holding pcomp_data_monthly.mat
matdir    = 'mat/';          % folder for intermediate and final results

%% DATA

% Load the monthly disaggregated PCE data file. It must provide calvec_m,
% calds_m, dp_disagg_m and share_avg_xfe_m (calds_m is also saved with the
% results).
load([data_path 'pcomp_data_monthly.mat']);
calvec        = calvec_m;
calds         = calds_m;          % calendar; column 1 = year, column 2 = month
dp_disagg     = dp_disagg_m;      % one column per sector
share_avg_xfe = share_avg_xfe_m;  % row t holds the sector weights used at origin t
n     = size(dp_disagg, 2);       % number of sectors
dnobs = size(calvec, 1);          % number of observations
n_h   = 12;                       % number of forecast periods (horizons)

% Forecast origins run from first_fcst through the last date in the data
first_fcst = [1984 12];           % first forecast period [year month]
last_fcst  = calds(dnobs, :);     % last forecast period

% Locate the first and last forecast origins in the calendar. Take a single
% match so t_first/t_last are always scalars.
t_first = find(calds(:,1) == first_fcst(1) & calds(:,2) == first_fcst(2), 1, 'first');
t_last  = find(calds(:,1) == last_fcst(1)  & calds(:,2) == last_fcst(2),  1, 'last');
if isempty(t_first)
    % Without this check the forecast loop would silently do nothing
    error('First forecast period %d:%d not found in the data calendar.', ...
        first_fcst(1), first_fcst(2));
end

% Either resume from the intermediate results file written by the forecast
% loop (i_restart == 1) or allocate fresh result arrays.
fname_intermediate = [matdir 'intermediate_mct_fcst_pced_' num2str(first_fcst(1)) '_' num2str(first_fcst(2)) '_' num2str(last_fcst(1)) '_' num2str(last_fcst(2)) '.mat'];
do_restart = (i_restart == 1);
if do_restart && ~isfile(fname_intermediate)
    % No previous run to resume from: start fresh instead of failing in load()
    warning('Restart file %s not found; starting from scratch.', fname_intermediate);
    do_restart = false;
end

if do_restart
    % Restores calds_m, rslt_fcsts, rslt_actual, rslt_mct_mean, rslt_mct_median
    load(fname_intermediate);
    % Results are saved only after a whole batch of origins has finished, so
    % the last non-NaN aggregate marks the last completed forecast origin.
    t_done = find(~isnan(rslt_mct_mean), 1, 'last');
    if ~isempty(t_done)
        % Never move before the first requested origin
        t_first = max(t_first, t_done + 1);
    end
    if t_first > t_last
        % Everything already done: calds(t_first,:) would be out of range here
        fprintf('All forecast periods through %d, %d already completed \n', calds(t_last,1), calds(t_last,2));
    else
        fprintf('Restarting from t = %d, %d \n', calds(t_first,1), calds(t_first,2));
    end
else
    rslt_fcsts      = NaN(n_h, n, dnobs); % mean forecasts: horizon x sector x origin
    rslt_actual     = NaN(n_h, n, dnobs); % realized values aligned with rslt_fcsts
    rslt_mct_mean   = NaN(dnobs, 1);      % mean of share-weighted aggregate draws
    rslt_mct_median = NaN(dnobs, 1);      % median of the same draws
end

% Set up priors and other parameters that are held fixed across all
% forecast origins.

% Model structure: number of MA lags, time-aggregated sectors and
% dependent sectors (one entry per sector)
n_lags    = repmat(3, [n, 1]);  % 3 MA lags for every sector
is_timeag = false(n, 1);        % no time-aggregated sectors (e.g. is_timeag(9) = true;)
i_depend  = zeros(n, 1);

% Estimation settings passed to estimate_MCT_watson
settings               = struct();
settings.show_progress = false;
settings.n_draw        = 1000;
settings.n_burn        = 1000;
settings.n_thin        = 2;
settings.n_lags        = n_lags;
settings.is_timeag     = is_timeag;
settings.i_depend      = i_depend;
settings.n_h           = n_h;

% Priors for theta/lambda/gamma/ps
nper          = 12;                 % periods per year
ps_mean       = 1 - 1/(4*nper);     % prior mean of ps
ps_prior_obs  = 10*nper;            % prior weight of ps, in pseudo-observations
prior         = struct();
prior.prec_MA = 0.1;
prior.nu_lam  = 12;
prior.s2_lam  = 0.25^2/60/nper;
prior.nu_gam  = 60;
prior.s2_gam  = 1/60/nper;
% a_ps/(a_ps+b_ps) = ps_mean and a_ps+b_ps = ps_prior_obs
% (presumably Beta prior parameters; TODO confirm in estimate_MCT_watson)
prior.a_ps    = ps_mean*ps_prior_obs;
prior.b_ps    = (1-ps_mean)*ps_prior_obs;

% Recursive forecasting loop. Forecast origins are processed in batches of
% n_save periods with parfor, and results are saved after every batch so a
% crash loses at most one batch (resume with i_restart = 1).
n_save = 10;
fname_intermediate = [matdir 'intermediate_mct_fcst_pced_' num2str(first_fcst(1)) '_' num2str(first_fcst(2)) '_' num2str(last_fcst(1)) '_' num2str(last_fcst(2)) '.mat'];
var_save = {'calds_m','rslt_fcsts','rslt_actual','rslt_mct_mean','rslt_mct_median'};
t_for = t_first;
while t_for <= t_last
    tic
    t1 = t_for;
    t2 = min(t1 + n_save - 1, t_last);
    parfor t = t1:t2
        % Estimate on data through origin t (starting at row 2 to avoid missing values)
        infl_t  = dp_disagg(2:t, :);
        share_t = share_avg_xfe(t, :);
        output_MCT = estimate_MCT_watson(infl_t, prior, settings, rng_seed);

        % Draw-mean forecasts (fcst_T assumed horizon x sector x draw; its
        % mean over dim 3 must fill the n_h x n slice of rslt_fcsts)
        rslt_fcsts(:,:,t) = mean(output_MCT.fcst_T, 3);

        % Realized values for each horizon that falls inside the sample
        for i_h = 1:n_h
            if t+i_h <= dnobs
                rslt_actual(i_h,:,t) = dp_disagg(t+i_h,:);
            end
        end

        % Share-weighted aggregate of the last-horizon forecast draws.
        % reshape (not squeeze) keeps a sector x draw matrix even when
        % n == 1 or there is a single draw.
        draws_h   = reshape(output_MCT.fcst_T(end,:,:), n, []);
        mct_draws = share_t * draws_h;
        rslt_mct_mean(t,1)   = mean(mct_draws, 2);
        rslt_mct_median(t,1) = median(mct_draws, 2);
    end
    fprintf('Saving through date: %d %d \n', calds(t2,1), calds(t2,2));
    toc
    % Checkpoint after every batch
    save(fname_intermediate, var_save{:});

    t_for = t2 + 1;
end
 
% Final save. The same results variables as the per-batch checkpoints are
% written to mct_fcst_pced_<first year>_<first month>_<last year>_<last month>.mat.
fcst_tag = strjoin(arrayfun(@num2str, [first_fcst(1:2) last_fcst(1:2)], 'UniformOutput', false), '_');
var_save = {'calds_m','rslt_fcsts','rslt_actual','rslt_mct_mean','rslt_mct_median'};
save([matdir 'mct_fcst_pced_' fcst_tag '.mat'], var_save{:});
