Hi, thank you for the great package.
I am working with the transition module of `mp_srslds` and came across this code:
ssm/ssm/extensions/mp_srslds/transitions_ext.py
Lines 82 to 97 in 6c856ad
```python
Ez = np.sum(expected_joints, axis=2) # marginal over z from T=1 to T-1
for k1 in range(self.K):
    for k2 in range(self.K):
        vtilde = vtildes[:,k1,k2][:,None] # SWAP?
        #Sticky terms
        if k1==k2:
            Rv = vtilde@self.Ss[k2:k2+1,:]
            hess += Ez[k1,k2] * \
                ( np.einsum('tn, ni, nj ->tij', -vtilde, self.Ss[k2:k2+1,:], self.Ss[k2:k2+1,:]) \
                + np.einsum('ti, tj -> tij', Rv, Rv))
        #Switching terms
        else:
            Rv = vtilde@self.Rs[k2:k2+1,:]
            hess += Ez[k1,k2] * \
                ( np.einsum('tn, ni, nj ->tij', -vtilde, self.Rs[k2:k2+1,:], self.Rs[k2:k2+1,:]) \
                + np.einsum('ti, tj -> tij', Rv, Rv))
```
where, on line 89, `Ez` is indexed by `k1` and `k2`. However, line 82 defines it as `Ez = np.sum(expected_joints, axis=2)` (commented "marginal over z from T=1 to T-1"), so I checked the dimensions of `expected_joints`:

Lines 186 to 198 in 6c856ad
```python
# Compute E[z_t, z_{t+1}] for t = 1, ..., T-1
# Note that this is an array of size T*K*K, which can be quite large.
# To be a bit more frugal with memory, first check if the given log_Ps
# are TxKxK. If so, instantiate the full expected joints as well, since
# we will need them for the M-step. However, if log_Ps is 1xKxK then we
# know that the transition matrix is stationary, and all we need for the
# M-step is the sum of the expected joints.
stationary = (Ps.shape[0] == 1)
if not stationary:
    expected_joints = alphas[:-1,:,None] + betas[1:,None,:] + ll[1:,None,:] + log_Ps
    expected_joints -= expected_joints.max((1,2))[:,None, None]
    expected_joints = np.exp(expected_joints)
    expected_joints /= expected_joints.sum((1,2))[:,None,None]
```
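To double-check the shape, the normalization quoted above can be run on random inputs. This is only a sketch: the sizes `T` and `K`, and the shapes assumed for `alphas`, `betas`, `ll` (each `(T, K)`) and `log_Ps` (`(T-1, K, K)` in the non-stationary case), are illustrative assumptions rather than values taken from the package.

```python
import numpy as np

# Hypothetical sizes and ASSUMED input shapes, chosen only to check broadcasting.
T, K = 6, 3
rng = np.random.default_rng(1)
alphas = rng.standard_normal((T, K))         # forward messages (log space), assumed (T, K)
betas = rng.standard_normal((T, K))          # backward messages (log space), assumed (T, K)
ll = rng.standard_normal((T, K))             # per-time-step log-likelihoods, assumed (T, K)
log_Ps = rng.standard_normal((T - 1, K, K))  # time-varying log transitions, assumed (T-1, K, K)

# Same four lines as the quoted non-stationary branch.
expected_joints = alphas[:-1, :, None] + betas[1:, None, :] + ll[1:, None, :] + log_Ps
expected_joints -= expected_joints.max((1, 2))[:, None, None]
expected_joints = np.exp(expected_joints)
expected_joints /= expected_joints.sum((1, 2))[:, None, None]

print(expected_joints.shape)  # (5, 3, 3), i.e. (T-1, K, K)
```

Each time slice `expected_joints[t]` is a normalized `K x K` table over the pair `(z_t, z_{t+1})`, so the leading axis is time.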
I believe it has shape `(T-1, K, K)`. Summing over `axis=2` would then give `Ez` the shape `(T-1, K)`, yet in the snippet above its first (time) axis is indexed with `k1` and its second axis with `k2`, which is confusing to me.
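The shape argument can be checked in isolation. The sketch below uses a random stand-in for `expected_joints`; the sizes `T`, `K`, `D`, the `term` array, and the per-time-step weighting at the end are hypothetical and only illustrate the question, not the package's intended behavior.

```python
import numpy as np

# Hypothetical sizes: T time steps, K discrete states, D continuous dimensions.
T, K, D = 6, 3, 2
rng = np.random.default_rng(2)

# Stand-in for expected_joints: one normalized K x K table per transition t -> t+1.
expected_joints = rng.random((T - 1, K, K))
expected_joints /= expected_joints.sum((1, 2), keepdims=True)

Ez = np.sum(expected_joints, axis=2)  # same reduction as line 82
print(Ez.shape)  # (5, 3), i.e. (T-1, K)

k1, k2 = 1, 2
# Ez[k1, k2] is a single number: the marginal P(z_t = k2) at time step t = k1.
print(np.isclose(Ez[k1, k2], expected_joints[k1, k2, :].sum()))  # True

# A scalar broadcasts against a per-time-step Hessian term without error,
# so the mismatch does not surface at runtime as long as K < T.
term = rng.standard_normal((T - 1, D, D))  # hypothetical (T-1, D, D) term
print((Ez[k1, k2] * term).shape)  # (5, 2, 2)

# One ALTERNATIVE reading (an assumption, not confirmed by the maintainers):
# weight each time step t by E[z_t = k1, z_{t+1} = k2].
w = expected_joints[:, k1, k2]          # shape (T-1,)
print((w[:, None, None] * term).shape)  # (5, 2, 2)
```

Because `k1` only runs up to `K-1`, `Ez[k1, k2]` raises no `IndexError` whenever `K < T`, which would let the indexing pass silently for typical sequence lengths. The last two lines show just one possible intent; whether the Hessian should be weighted per time step is exactly what the question asks.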
Could you clarify if this behavior is intentional, or if there might be a mistake in how Ez is used? I may be missing something here, so I’d appreciate your insight. Thanks for your time and support!