function [out1,out2,out3] = model_func_re(flag,s,x,e,bbeta,delta,ies,kappa_y,kappa_q,Lambda_pi,Lambda_y,s_d,CqCy,ygapss,gain,bbetaU,rho_xi,rho_rn,sig_xi,sig_rn,sig_u,flag_zlb,gamma)
% MODEL_FUNC_RE  Model function returning bounds, reward or state transition.
%
%   The meaning of the outputs depends on FLAG:
%     'b' : out1 = lower bounds on the controls, out2 = upper bounds,
%           out3 = [] (unused).
%     'f' : out1 = reward f, out2 = fx (derivative w.r.t. controls),
%           out3 = fxx (second derivative w.r.t. controls).
%     'g' : out1 = next-period states g, out2 = gx, out3 = gxx (all zeros,
%           because g is linear in the controls).
%
%   Inputs actually read by this function:
%     flag       - one of 'b', 'f', 'g'.
%     s          - matrix of states, one node per row; columns are addressed
%                  through psilag_idx, mulag_idx, rn_idx, u_idx.
%     x          - matrix of controls, one node per row; columns are addressed
%                  through ygap_idx, pi_idx, i_idx, psi_idx, lambda_idx.
%     e          - matrix of shocks (used for flag 'g' only); columns are
%                  addressed through e_rn_idx and e_u_idx.
%     bbeta, ies, kappa_y, kappa_q, Lambda_pi, Lambda_y, s_d, rho_rn, gamma
%                - model parameters entering the formulas below.
%     flag_zlb   - if true, the lower bound on the i control is set to
%                  -(1-bbeta)/bbeta; otherwise it is unbounded.
%   The remaining arguments (delta, CqCy, ygapss, gain, bbetaU, rho_xi,
%   sig_xi, sig_rn, sig_u) are accepted for interface compatibility but are
%   not used here.
%
%   An unsupported FLAG raises the error 'model_func_re:unknownFlag'.

% Fail early with a clear message instead of leaving the outputs unassigned.
if ~any(strcmp(flag, {'b','f','g'}))
    error('model_func_re:unknownFlag', ...
          'Unknown flag; expected ''b'', ''f'' or ''g''.');
end

[n_nodes, n_states] = size(s);
[~, n_controls]     = size(x);

% define_idx is a script defined outside this file; it is expected to create
% the *_idx column indices used below (TODO confirm against define_idx.m).
define_idx

% States.
psilag = s(:,psilag_idx);
mulag  = s(:,mulag_idx);
rn     = s(:,rn_idx);
u      = s(:,u_idx);

% Controls x = [ygap, pi, i, psi, lambda].  The inflation and interest-rate
% columns are stored under names that do not shadow MATLAB's built-in
% constants pi and i.
ygap    = x(:,ygap_idx);
infl    = x(:,pi_idx);
nomrate = x(:,i_idx);
psi     = x(:,psi_idx);
lambda  = x(:,lambda_idx);

% Constants of the psilag law of motion (hard-coded in the original code).
w         = 0.95;             % weight on the (1-gamma)*psilag + gamma*pi mix
psi_const = (1 - w)*4/400;    % constant term of the law of motion

switch flag
    case 'b' % bounds on control variables
        out1 = -inf*ones(size(x));      % lower bound
        out2 =  inf*ones(size(x));      % upper bound
        out3 = [];                      % no third output for bounds

        if flag_zlb
            out1(:,i_idx) = -(1-bbeta)/bbeta; % zero-lower bound
        end

    case 'f' % reward (states in rows)
        out2 = zeros(n_nodes,n_controls);               % fx
        out3 = zeros(n_nodes,n_controls,n_controls);    % fxx

        % Next-period psilag; the same expression is the transition in 'g'.
        psi_next = psi_const + w*(1-gamma)*psilag + w*gamma*infl;

        % Terms multiplied by psi and by lambda in the reward.  They are
        % also the derivatives of the reward w.r.t. psi and lambda.
        psi_term    = infl - kappa_y*ygap - kappa_q*s_d - u - bbeta*psi_next;
        lambda_term = ygap + ies*(nomrate - psi_next - rn);

        out1 = -(Lambda_pi * infl.^2 + Lambda_y * ygap.^2) + psi.*psi_term ...
               + lambda.*lambda_term - mulag.*(ygap/bbeta);

        % derivative of reward (states in rows) w.r.t. controls (in columns)
        out2(:,ygap_idx)   = -2*Lambda_y*ygap - psi*kappa_y + lambda - mulag/bbeta;
        out2(:,pi_idx)     = -2*Lambda_pi*infl + psi*(1-bbeta*gamma*w) - lambda*ies*gamma*w;
        out2(:,i_idx)      = lambda*ies;
        out2(:,psi_idx)    = psi_term;
        out2(:,lambda_idx) = lambda_term;

        % 2nd derivative of reward w.r.t. controls (symmetric; entries not
        % set here stay zero)
        out3(:,ygap_idx,ygap_idx)     = -2*Lambda_y;
        out3(:,ygap_idx,psi_idx)      = -kappa_y;
        out3(:,ygap_idx,lambda_idx)   = 1;
        out3(:,pi_idx,pi_idx)         = -2*Lambda_pi;
        out3(:,pi_idx,psi_idx)        = 1-bbeta*gamma*w;
        out3(:,pi_idx,lambda_idx)     = -ies*gamma*w;
        out3(:,i_idx,lambda_idx)      = ies;
        out3(:,psi_idx,ygap_idx)      = -kappa_y;
        out3(:,psi_idx,pi_idx)        = 1-bbeta*gamma*w;
        out3(:,lambda_idx,ygap_idx)   = 1;
        out3(:,lambda_idx,i_idx)      = ies;
        out3(:,lambda_idx,pi_idx)     = -ies*gamma*w;

    case 'g' % state transition function
        out1 = zeros(n_nodes,n_states);                        % g
        out2 = zeros(n_nodes,n_states,n_controls);             % gx
        out3 = zeros(n_nodes,n_states,n_controls,n_controls);  % gxx

        out1(:,psilag_idx) = psi_const + w*(1-gamma)*psilag + w*gamma*infl;
        out1(:,mulag_idx)  = lambda;
        out1(:,rn_idx)     = rho_rn * rn + e(:,e_rn_idx);
        out1(:,u_idx)      = e(:,e_u_idx);

        % Only two transitions depend on the controls.
        out2(:,psilag_idx,pi_idx)    = gamma*w;
        out2(:,mulag_idx,lambda_idx) = 1;

end

end