function [out1,out2,out3] = func(flag,s,x,e,beta,ies,kappa_y,kappa_q,delta,Lambda_pi,Lambda_y,s_d,CqCy,ygapss,gain,rho_u,rho_rn,flag_zlb)
%FUNC Control bounds, reward and state transition for a dynamic program.
%
%   [XL,XU]     = FUNC('b',S,X,E,...) returns lower/upper bounds on the controls.
%   [F,FX,FXX]  = FUNC('f',S,X,E,...) returns the reward and its first and
%                 second derivatives with respect to the controls.
%   [G,GX,GXX]  = FUNC('g',S,X,E,...) returns next-period states and their
%                 first and second derivatives with respect to the controls.
%
%   The flag/output layout appears to follow the CompEcon dpsolve
%   model-function convention (presumably the intended caller -- confirm).
%
%   Inputs
%     flag   'b', 'f' or 'g'; any other value raises error 'func:badFlag'.
%     s      n_nodes-by-n_states states; columns read here:
%              1 = u, 2 = psilag, 3 = rn, 4 = mulag
%     x      n_nodes-by-n_controls controls; columns:
%              1 = ygap, 2 = pi, 3 = i, 4 = psi, 5 = lambda
%     e      shocks, columns 1 = e_u, 2 = e_rn (only read when flag is 'g').
%     beta ... flag_zlb
%            Model parameters. delta, CqCy and gain are accepted for
%            interface compatibility but are not used by this function.
%            flag_zlb (logical) switches the lower bound on i on/off.
%
%   Outputs
%     'b': out1/out2 = lower/upper bounds, size(x); out3 = [].
%     'f': out1 = reward (n_nodes-by-1); out2 = gradient
%          (n_nodes-by-n_controls); out3 = Hessian
%          (n_nodes-by-n_controls-by-n_controls, constant and symmetric).
%     'g': out1 = next states (n_nodes-by-n_states); out2 = Jacobian
%          (n_nodes-by-n_states-by-n_controls); out3 = second derivatives,
%          all zero because the transition is linear in the controls.

[n_nodes, n_states] = size(s);
[~, n_controls]     = size(x);

% State columns of s. A fifth state (qgap) is currently disabled.
u_idx       = 1;
psilag_idx  = 2;
rn_idx      = 3;
mulag_idx   = 4;
% qgap_idx    = 5;

% Control columns of x: [ygap, pi, i, psi, lambda].
ygap_idx    = 1;
pi_idx      = 2;
i_idx       = 3;
psi_idx     = 4;
lambda_idx  = 5;

% Shock columns of e: [e_u, e_rn].
e_u_idx     = 1;
e_rn_idx    = 2;

u           = s(:,u_idx);
psilag      = s(:,psilag_idx);
rn          = s(:,rn_idx);
mulag       = s(:,mulag_idx);

% infl / inom hold the pi and i controls; named so they do not shadow
% MATLAB's built-in constant pi and imaginary unit i.
ygap        = x(:,ygap_idx);
infl        = x(:,pi_idx);
inom        = x(:,i_idx);
psi         = x(:,psi_idx);
lambda      = x(:,lambda_idx);

switch flag
    case 'b' % bounds on control variables
        out1 = -inf*ones(size(x));      % lower bound
        out2 =  inf*ones(size(x));      % upper bound
        out3 = [];                      % not used; set so a 3-output call does not fail

        if flag_zlb
            % Lower bound on i. Presumably the negative of the steady-state
            % net rate in percent, i.e. a zero lower bound -- confirm.
            out1(:,i_idx) = -(1-beta)/beta*100;
        end

    case 'f' % reward and its derivatives w.r.t. the controls
        % rn scaled by 1/(1-rho_rn); presumably the discounted sum of the
        % AR(1) process rn with coefficient rho_rn -- confirm.
        exprn = rn/(1-rho_rn);

        % Quadratic loss in (pi, ygap) plus multiplier-weighted terms in
        % psi, lambda and the lagged states psilag, mulag.
        out1 = -(Lambda_pi * infl.^2 + Lambda_y * ygap.^2) ...
            + (psi - psilag).*infl ...
            - psi .* (kappa_y*ygap + kappa_q*s_d + u) ...
            + lambda .* (ygap - ygapss + ies*(inom - exprn/ies)) ...
            + mulag * ies .* (inom - infl);

        % Gradient: nodes in rows, controls in columns.
        out2 = zeros(n_nodes,n_controls);
        out2(:,ygap_idx)        = -2*Lambda_y * ygap - psi*kappa_y + lambda;
        out2(:,pi_idx)          = -2*Lambda_pi * infl + (psi - psilag) - mulag*ies;
        out2(:,i_idx)           = lambda*ies + mulag*ies;
        out2(:,psi_idx)         = infl - kappa_y*ygap - kappa_q*s_d - u;
        out2(:,lambda_idx)      = ygap - ygapss + ies*(inom - exprn/ies);

        % Hessian: constant across nodes; each cross term is set in both
        % (a,b) and (b,a) positions to keep it symmetric.
        out3 = zeros(n_nodes,n_controls,n_controls);
        out3(:,ygap_idx,ygap_idx)       = -2*Lambda_y;
        out3(:,pi_idx,pi_idx)           = -2*Lambda_pi;
        out3(:,ygap_idx,psi_idx)        = -kappa_y;
        out3(:,psi_idx,ygap_idx)        = -kappa_y;
        out3(:,ygap_idx,lambda_idx)     = 1;
        out3(:,lambda_idx,ygap_idx)     = 1;
        out3(:,pi_idx,psi_idx)          = 1;
        out3(:,psi_idx,pi_idx)          = 1;
        out3(:,i_idx,lambda_idx)        = ies;
        out3(:,lambda_idx,i_idx)        = ies;

    case 'g' % state transition and its derivatives w.r.t. the controls
        out1 = zeros(n_nodes,n_states);                        % g
        out1(:,u_idx)          = -(1-rho_u)*kappa_q*s_d + rho_u * u + e(:,e_u_idx);
        out1(:,psilag_idx)     = psi;
        out1(:,rn_idx)         = rho_rn * rn + e(:,e_rn_idx);
        out1(:,mulag_idx)      = 1/beta*(lambda + mulag);

        % Only psilag' (via psi) and mulag' (via lambda) depend on controls.
        out2 = zeros(n_nodes,n_states,n_controls);             % gx
        out2(:,psilag_idx,psi_idx)      = 1;
        out2(:,mulag_idx,lambda_idx)    = 1/beta;

        % Transition is linear in the controls, so gxx is identically zero.
        out3 = zeros(n_nodes,n_states,n_controls,n_controls);  % gxx

    otherwise
        error('func:badFlag', 'Unknown flag "%s"; expected ''b'', ''f'' or ''g''.', flag);
end

end