function [out1,out2,out3] = model_func(flag,s,x,e,bbeta,delta,ies,kappa_y,kappa_q,Lambda_pi,Lambda_y,s_d,CqCy,ygap_em_terminal,gain,bbetaU,rho_xi,rho_rn,sig_xi,sig_rn,flag_zlb,gamma)
%MODEL_FUNC  Control bounds, reward and state transition for the DP model.
%   [OUT1,OUT2,OUT3] = MODEL_FUNC(FLAG,S,X,E,...) evaluates the model at the
%   n_nodes rows of the state matrix S and control matrix X. The flag
%   protocol appears to follow the CompEcon dpsolve model-function
%   convention (NOTE(review): confirm against the calling solver).
%
%   FLAG = 'b'  OUT1 / OUT2 = lower / upper control bounds, same size as X.
%               All controls are unbounded, except that when FLAG_ZLB is
%               true the control in column i_idx is bounded below by
%               -(1-bbeta)/bbeta.
%   FLAG = 'f'  OUT1 = reward (n_nodes x 1),
%               OUT2 = gradient w.r.t. controls (n_nodes x n_controls),
%               OUT3 = Hessian w.r.t. controls
%                      (n_nodes x n_controls x n_controls).
%               E is not used in this case.
%   FLAG = 'g'  OUT1 = next-period states (n_nodes x n_states) given the
%               shock matrix E,
%               OUT2 = Jacobian w.r.t. controls
%                      (n_nodes x n_states x n_controls),
%               OUT3 = second derivatives (all zero).
%   Any other FLAG raises an error.
%
%   All column positions (*_idx) are created by the script DEFINE_IDX.
%   SIG_XI and SIG_RN are accepted for interface compatibility but are not
%   used here.

[n_nodes, n_states] = size(s);
[~, n_controls] = size(x);

% Column indices for states, controls and shocks. This is a script, so it
% runs in (and defines variables in) this function's workspace.
define_idx

% --- states ---------------------------------------------------------------
psilag   = s(:,psilag_idx);
mulag    = s(:,mulag_idx);
rn       = s(:,rn_idx);
logxi    = s(:,logxi_idx);
logqu    = s(:,logqu_idx);
loglagqu = s(:,loglagqu_idx);

qu    = exp(logqu);
lagqu = exp(loglagqu);

% --- controls x = [ygap, pi, i, psi, lambda] -----------------------------
% Inflation and the interest-rate control are held in infl / inom so the
% MATLAB builtins pi and i are not shadowed.
ygap   = x(:,ygap_idx);
infl   = x(:,pi_idx);
inom   = x(:,i_idx);
psi    = x(:,psi_idx);
lambda = x(:,lambda_idx);

% Infinite sum of future natural rates: sum_{k>=0} rho_rn^k * rn.
exprn = rn/(1-rho_rn);

switch flag
    case 'b' % bounds on control variables
        out1 = -inf*ones(size(x));      % lower bound
        out2 = inf*ones(size(x));       % upper bound

        if flag_zlb
            % zero lower bound on the interest-rate control
            out1(:,i_idx) = -(1-bbeta)/bbeta;
        end

    case 'f' % reward (states in rows)
        bbetat = beta_t_from_qu(qu, logxi, bbeta, delta, bbetaU);
        bdelta = bbeta*(1-delta);

        quhatgap = ( (1-bdelta)./(1-bdelta*bbetat) - (1-bdelta)/(1-bdelta*rho_xi) ).*logxi ...
                   + (bdelta*(bbetat-1))./(1-bdelta*bbetat) + s_d;

        % Residuals of the two constraints priced by the multiplier controls
        % psi and lambda. Each one is also that multiplier's gradient entry.
        % NOTE(review): u is neither an argument nor read from s or x; it
        % must be created by define_idx, otherwise this line errors. Confirm.
        res_psi    = infl - kappa_y*ygap - kappa_q*quhatgap - u - bbeta*psilag;
        res_lambda = ygap - ygap_em_terminal + ies*(inom - psilag - exprn) - CqCy*(s_d - quhatgap);

        % quadratic loss + multiplier-weighted constraint residuals
        out1 = -(Lambda_pi * infl.^2 + Lambda_y * ygap.^2) ...
            + psi.*res_psi ...
            + lambda.*res_lambda ...
            + mulag * ies .* (inom - psilag);

        % derivative of reward (states in rows) w.r.t. controls (in columns)
        out2 = zeros(n_nodes,n_controls);
        out2(:,ygap_idx)   = -2*Lambda_y * ygap - psi*kappa_y + lambda;
        out2(:,pi_idx)     = -2*Lambda_pi * infl + psi;
        out2(:,i_idx)      = lambda*ies + mulag*ies;
        out2(:,psi_idx)    = res_psi;
        out2(:,lambda_idx) = res_lambda;

        % 2nd derivative of reward w.r.t. controls (constant, symmetric)
        out3 = zeros(n_nodes,n_controls,n_controls);
        out3(:,ygap_idx,ygap_idx)   = -2*Lambda_y;
        out3(:,ygap_idx,psi_idx)    = -kappa_y;
        out3(:,ygap_idx,lambda_idx) = 1;
        out3(:,pi_idx,pi_idx)       = -2*Lambda_pi;
        out3(:,pi_idx,psi_idx)      = 1;
        out3(:,i_idx,lambda_idx)    = ies;
        out3(:,psi_idx,pi_idx)      = 1;
        out3(:,psi_idx,ygap_idx)    = -kappa_y;
        out3(:,lambda_idx,ygap_idx) = 1;
        out3(:,lambda_idx,i_idx)    = ies;

    case 'g' % state transition function
        out1 = zeros(n_nodes,n_states);                        % g
        out2 = zeros(n_nodes,n_states,n_controls);             % gx
        out3 = zeros(n_nodes,n_states,n_controls,n_controls);  % gxx (stays zero)

        % states driven by controls
        out1(:,psilag_idx) = (1-gamma)*psilag + gamma*infl;
        out1(:,mulag_idx)  = 1/bbeta*(lambda + mulag);
        out2(:,psilag_idx,pi_idx)    = gamma;
        out2(:,mulag_idx,lambda_idx) = 1/bbeta;

        % exogenous AR(1) states
        out1(:,rn_idx)    = rho_rn * rn + e(:,e_rn_idx);
        logxi_next        = rho_xi * logxi + e(:,e_xi_idx);
        out1(:,logxi_idx) = logxi_next;

        % beta_t moves toward the realized ratio qu/lagqu at rate gain,
        % capped above by bbetaU
        bbetat      = beta_t_from_qu(qu, logxi, bbeta, delta, bbetaU);
        bbetat_next = min(bbetaU, bbetat + gain*(qu./lagqu - bbetat));

        % qu is stored in logs: qu = exp(logxi) / (1 - bbeta*(1-delta)*beta_t).
        % Written as a log difference to avoid an exp/log round-trip.
        out1(:,loglagqu_idx) = logqu;
        out1(:,logqu_idx)    = logxi_next - log(1 - bbeta*(1-delta)*bbetat_next);

    otherwise
        error('model_func:unknownFlag', ...
              'Unknown flag ''%s''; expected ''b'', ''f'' or ''g''.', flag);
end

end

function bbetat = beta_t_from_qu(qu, logxi, bbeta, delta, bbetaU)
%BETA_T_FROM_QU  Recover beta_t from qu and exp(logxi), clipped to [0, bbetaU].
%   Inverts qu = exp(logxi) / (1 - bbeta*(1-delta)*beta_t), the relation the
%   'g' case of MODEL_FUNC uses when writing logqu.
bbetat = max(0, min(bbetaU, (qu - exp(logxi))./qu / (bbeta * (1-delta))));
end