Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions pygam/penalties.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import scipy as sp


def derivative(n, coef, derivative=2, periodic=False):
def derivative(n, coef, derivative=2, periodic=False, square_form=True):
"""
Builds a penalty matrix for P-Splines with continuous features.
Penalizes the squared differences between basis coefficients.
Expand All @@ -24,29 +24,42 @@ def derivative(n, coef, derivative=2, periodic=False):
derivative is 1, we penalize 1st order derivatives,
derivative is 2, we penalize 2nd order derivatives, etc

periodic: bool, defualt: False
whether to make the penalty n-periodic
effectively makes the matrix circulant

square_form: bool, default: True
whether to return the square form, which is typically used.

Returns
-------
penalty matrix : sparse csc matrix of shape (n,n)
"""
if n == 1:
# no derivative for constant functions
return sp.sparse.csc_matrix(0.0)
D = sparse_diff(
sp.sparse.identity(n + 2 * derivative * periodic).tocsc(), n=derivative
).tolil()
D = (
sparse_diff(
sp.sparse.identity(n + 2 * derivative * periodic).tocsc(), n=derivative
)
.tolil()
.T
)

if periodic:
# wrap penalty
cols = D[:, :derivative]
D[:, -2 * derivative : -derivative] += cols * (-1) ** derivative

# do symmetric operation on lower half of matrix
n_rows = int((n + 2 * derivative) / 2)
D[-n_rows:] = D[:n_rows][::-1, ::-1]
leftcols = D[:, :derivative]
righttcols = D[:, -derivative:]
D[:, -derivative - 1 : -1] += leftcols[:, ::-1]
D[:, 1 : 1 + derivative :] += righttcols[:, ::-1]

# keep only the center of the augmented matrix
D = D[derivative:-derivative, derivative:-derivative]
return D.dot(D.T).tocsc()
D = D[derivative - 1 : -derivative + 1, derivative:-derivative]

if not square_form:
return D

return D.T.dot(D).tocsc()


def periodic(n, coef, derivative=2, _penalty=derivative):
Expand Down
Loading