diff --git a/improved_diffusion/gaussian_diffusion.py b/improved_diffusion/gaussian_diffusion.py index 403d474f3b..ae634cdece 100644 --- a/improved_diffusion/gaussian_diffusion.py +++ b/improved_diffusion/gaussian_diffusion.py @@ -38,6 +38,18 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): num_diffusion_timesteps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, ) + elif schedule_name == "constant": + scale = 1000 / num_diffusion_timesteps + beta_end = scale * 0.02 + return beta_end * np.ones( + num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "sigmoid": + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + betas = np.linspace(-6, 6, num_diffusion_timesteps) + return (1/(np.exp(betas) + 1)) * (beta_end - beta_start) + beta_start else: raise NotImplementedError(f"unknown beta schedule: {schedule_name}")