Issue with hessian computation in vLEM for mp_srlds #171

@XiaoliangWang2001

Description

Hi, thank you for the great package.

I am working with the transition module of mp_srlds and came across the following code in the Hessian computation:

Ez = np.sum(expected_joints, axis=2) # marginal over z from T=1 to T-1
for k1 in range(self.K):
    for k2 in range(self.K):
        vtilde = vtildes[:,k1,k2][:,None] # SWAP?
        #Sticky terms
        if k1==k2:
            Rv = vtilde@self.Ss[k2:k2+1,:]
            hess += Ez[k1,k2] * \
                    ( np.einsum('tn, ni, nj ->tij', -vtilde, self.Ss[k2:k2+1,:], self.Ss[k2:k2+1,:]) \
                    + np.einsum('ti, tj -> tij', Rv, Rv))
        #Switching terms
        else:
            Rv = vtilde@self.Rs[k2:k2+1,:]
            hess += Ez[k1,k2] * \
                    ( np.einsum('tn, ni, nj ->tij', -vtilde, self.Rs[k2:k2+1,:], self.Rs[k2:k2+1,:]) \
                    + np.einsum('ti, tj -> tij', Rv, Rv))

Here, on line 89 of the source file (the `hess += Ez[k1,k2] * ...` statement), Ez is indexed by k1 and k2. However, Ez is defined on line 82 as:
Ez = np.sum(expected_joints, axis=2) # marginal over z from T=1 to T-1
and I then checked the dimensions of expected_joints:

ssm/ssm/messages.py

Lines 186 to 198 in 6c856ad

# Compute E[z_t, z_{t+1}] for t = 1, ..., T-1
# Note that this is an array of size T*K*K, which can be quite large.
# To be a bit more frugal with memory, first check if the given log_Ps
# are TxKxK. If so, instantiate the full expected joints as well, since
# we will need them for the M-step. However, if log_Ps is 1xKxK then we
# know that the transition matrix is stationary, and all we need for the
# M-step is the sum of the expected joints.
stationary = (Ps.shape[0] == 1)
if not stationary:
    expected_joints = alphas[:-1,:,None] + betas[1:,None,:] + ll[1:,None,:] + log_Ps
    expected_joints -= expected_joints.max((1,2))[:,None, None]
    expected_joints = np.exp(expected_joints)
    expected_joints /= expected_joints.sum((1,2))[:,None,None]

I believe expected_joints should have dimensions (T-1, K, K) in the non-stationary case. As a result, Ez would have dimensions (T-1, K), with time on its first axis. Yet in the Hessian code above, that first axis is indexed with k1, a state index, which is confusing to me.
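To make the shape argument concrete, here is a minimal sketch that reuses the expression from the quoted messages.py snippet on random stand-in arrays. The sizes T and K, the random inputs, and the assumed shapes of alphas, betas, ll and log_Ps are my own assumptions for illustration; they are not taken from the package.

```python
import numpy as np

# Hypothetical sizes and random stand-ins for the forward-backward quantities.
T, K = 6, 3
rng = np.random.default_rng(0)
alphas = rng.normal(size=(T, K))
betas = rng.normal(size=(T, K))
ll = rng.normal(size=(T, K))
log_Ps = rng.normal(size=(T - 1, K, K))  # assumed non-stationary transitions

# Same expression as in the quoted messages.py snippet.
expected_joints = alphas[:-1, :, None] + betas[1:, None, :] + ll[1:, None, :] + log_Ps
expected_joints -= expected_joints.max((1, 2))[:, None, None]
expected_joints = np.exp(expected_joints)
expected_joints /= expected_joints.sum((1, 2))[:, None, None]

# Same reduction as in the Hessian code.
Ez = np.sum(expected_joints, axis=2)

print(expected_joints.shape)  # (5, 3, 3), i.e. (T-1, K, K)
print(Ez.shape)               # (5, 3), i.e. (T-1, K)

# With these shapes, Ez[k1, k2] reads time step k1 and state k2,
# not a (k1, k2) pair of states.
k1, k2 = 2, 1
same = np.isclose(Ez[k1, k2], expected_joints[k1, k2, :].sum())
print(same)
```

Under these assumptions, Ez[k1, k2] runs without error whenever T-1 >= K, so the indexing would not fail loudly even if it is not what was intended.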

Could you clarify if this behavior is intentional, or if there might be a mistake in how Ez is used? I may be missing something here, so I’d appreciate your insight. Thanks for your time and support!
