function results = imm_pf(X,r,Dist,opts)
    
    % Latent cause-modulated Rescorla-Wagner model, with latent cause
    % assignments tracked by a set of opts.M sampled particles.
    %
    % USAGE: results = imm_pf(X,r,Dist,[opts])
    %
    % INPUTS:
    %   X - [N x D] matrix, where X(n,d) specifies the stimulus intensity
    %       of stimulus d on trial n
    %   r - [N x 1] US on each trial
    %   Dist - [N x N] temporal distance between each timepoint. Dist(t,t+1)
    %          (rounded) also bounds the number of EM iterations on trial t.
    %   opts (optional) - options structure, containing the following fields:
    %                       .alpha - concentration parameter, scalar or
    %                                [N x 1] (default: 0.1)
    %                       .g - temporal scaling parameter (default: 1)
    %                       .psi - [N x 1] binary vector specifying when protein
    %                              synthesis inhibitor is injected (default: all 0)
    %                       .eta - learning rate (default: 0.2)
    %                       .maxIter - maximum number of iterations between
    %                                  each trial (default: 3)
    %                       .w0 - initial weight value (default: 0)
    %                       .sr - US variance (default: 0.4)
    %                       .sx - stimulus variance (default: 1)
    %                       .theta - response threshold (default: 0.03);
    %                                set to NaN to return the raw prediction
    %                       .lambda - response gain (default: 0.005)
    %                       .K - max number of latent sources (default: 15)
    %                       .M - number of particles (default: 100)
    %
    % OUTPUTS:
    %   results - structure containing the following fields:
    %               .V - [N x 1] conditioned response on each trial
    %               .Zp - [N x K] particle-averaged latent cause posterior
    %                     (including the US) from the last EM iteration run
    %                     on each trial; rows stay 0 if no iteration ran
    %               .W - {N x nIter} cell, [D x K] weights after each iteration
    %               .P - {N x nIter} cell, [M x K] per-particle posterior
    %                    after each iteration
    %               .Z - [N x K x M] sampled latent cause assignments
    %                    (one-hot along K) for each particle
    %               .S - [N x N] temporal kernel, Dist.^(-g)
    
    %###### DEFAULT PARAMETERS ##########
    def_opts = struct('alpha',0.1,'g',1,'psi',0,'eta',0.2,'maxIter',3,...
        'w0',0,'sr',0.4,'sx',1,'theta',0.03,'lambda',0.005,'K',15,'M',100);
    F = fieldnames(def_opts);
    
    if nargin < 4 || isempty(opts)
        opts = def_opts;
    else
        % fill in any missing or empty field with its default
        for f = 1:length(F)
            if ~isfield(opts,F{f}) || isempty(opts.(F{f}))
                opts.(F{f}) = def_opts.(F{f});
            end
        end
    end
    
    %######## INITIALIZATION ############
    [T, D] = size(X);
    % (:) forces a column so a row-vector option cannot broadcast to [T x T]
    alpha = opts.alpha(:).*ones(T,1);
    psi = opts.psi(:).*ones(T,1);
    Z = zeros(T,opts.K,opts.M);
    results.V = zeros(T,1);
    results.Zp = zeros(T,opts.K);      % fixed shape even if a trial runs 0 iterations
    W = zeros(D,opts.K) + opts.w0;
    S = Dist.^(-opts.g);
    L = zeros(opts.M,opts.K);
    
    %########## RUN INFERENCE ############
    
    % loop over timepoints
    for t = 1:T
        
        % determine how many EM iterations to perform based on ITI
        if t == T
            nIter = 1;
        else
            nIter = min(opts.maxIter,round(Dist(t,t+1)));
        end
        
        % calculate (unnormalized) posterior, not including reward
        for m = 1:opts.M
            z = squeeze(Z(1:t-1,:,m));
            N = sum(z,1);              % cluster counts
            prior = S(1:t-1,t)'*z;     % ddCRP prior
            prior(find(N==0,1)) = alpha(t);     % probability of new cluster
            L(m,:) = prior./sum(prior);            % normalize prior
            xsum = X(1:t-1,:)'*z;      % [D x K] matrix of feature sums
            nu = opts.sx./(N+opts.sx) + opts.sx;
            % bsxfun keeps the [D x K]./[1 x K] division valid without
            % relying on implicit expansion (matches the rest of the file)
            xhat = bsxfun(@rdivide,xsum,N+opts.sx);
            for d = 1:D
                L(m,:) = L(m,:) .* normpdf(X(t,d),xhat(d,:),sqrt(nu)); % likelihood
            end
        end
        
        % reward prediction, before feedback (averaged over particles)
        post = bsxfun(@rdivide,L,sum(L,2));
        results.V(t) = mean((X(t,:)*W)*post');
        if ~isnan(opts.theta); results.V(t) = 1-normcdf(opts.theta,results.V(t),opts.lambda); end
        
        % loop over EM iterations
        for iter = 1:nIter
            V = X(t,:)*W;                               % reward prediction
            p = normpdf(r(t),V,sqrt(opts.sr));
            post = bsxfun(@times,L,p);    % unnormalized posterior with reward
            post = bsxfun(@rdivide,post,sum(post,2));
            % explicit dim: mean(post) would collapse to a scalar when M == 1
            avgPost = mean(post,1);
            results.Zp(t,:) = avgPost;
            x = repmat(X(t,:)',1,opts.K);
            err = r(t)-V;                 % fixed within this iteration
            for m = 1:opts.M
                rpe = repmat(err.*post(m,:),D,1);                % reward prediction error
                W = W + opts.eta.*x.*rpe/opts.M;                   % weight update
            end
            if psi(t)==1
                % PSI decay, applied once per iteration using the
                % particle-averaged posterior so W stays [D x K]. The
                % previous form used the full [M x K] posterior inside the
                % particle loop, which is not conformable with W for M > 1.
                % NOTE(review): averaging mirrors the 1/M weighting of the
                % update above; confirm this is the intended PSI semantics.
                W = W.*(1-repmat(avgPost,D,1));
            end
            results.W{t,iter} = W;
            results.P{t,iter} = post;
        end
        
        % cluster assignment: draw one cluster per particle from the
        % particle-averaged posterior (pre-reward posterior if nIter == 0)
        k = fastrandsample(mean(post,1),opts.M);
        for m = 1:opts.M; Z(t,k(m),m)=1; end
        
    end
    
    %store results
    results.Z = Z;
    results.S = S;