Skip to content

Commit

Permalink
Add DPMSolverScheduler trait
Browse files Browse the repository at this point in the history
  • Loading branch information
rockerBOO committed Jan 8, 2023
1 parent dea4d4a commit 83b28b3
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 39 deletions.
45 changes: 45 additions & 0 deletions src/schedulers/dpmsolver.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use tch::Tensor;

use crate::schedulers::BetaSchedule;
use crate::schedulers::PredictionType;

Expand Down Expand Up @@ -65,3 +67,46 @@ impl Default for DPMSolverSchedulerConfig {
}
}
}

pub trait DPMSolverScheduler {
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self;
fn convert_model_output(
&self,
model_output: &Tensor,
timestep: usize,
sample: &Tensor,
) -> Tensor;

fn first_order_update(
&self,
model_output: Tensor,
timestep: usize,
prev_timestep: usize,
sample: &Tensor,
) -> Tensor;

fn second_order_update(
&self,
model_output_list: &Vec<Tensor>,
timestep_list: [usize; 2],
prev_timestep: usize,
sample: &Tensor,
) -> Tensor;

fn third_order_update(
&self,
model_output_list: &Vec<Tensor>,
timestep_list: [usize; 3],
prev_timestep: usize,
sample: &Tensor,
) -> Tensor;

fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor;

fn timesteps(&self) -> &[usize];
fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Tensor;


fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor;
fn init_noise_sigma(&self) -> f64;
}
46 changes: 23 additions & 23 deletions src/schedulers/dpmsolver_multistep.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
use super::{betas_for_alpha_bar, BetaSchedule, PredictionType, dpmsolver::{DPMSolverSchedulerConfig, DPMSolverAlgorithmType, DPMSolverType}};
use super::{
betas_for_alpha_bar,
dpmsolver::{
DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType,
},
BetaSchedule, PredictionType,
};
use tch::{kind, Kind, Tensor};

pub struct DPMSolverMultistepScheduler {
Expand All @@ -15,8 +21,8 @@ pub struct DPMSolverMultistepScheduler {
pub config: DPMSolverSchedulerConfig,
}

impl DPMSolverMultistepScheduler {
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
impl DPMSolverScheduler for DPMSolverMultistepScheduler {
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => Tensor::linspace(
config.beta_start.sqrt(),
Expand Down Expand Up @@ -117,7 +123,7 @@ impl DPMSolverMultistepScheduler {

/// One step for the first-order DPM-Solver (equivalent to DDIM).
/// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
fn dpm_solver_first_order_update(
fn first_order_update(
&self,
model_output: Tensor,
timestep: usize,
Expand All @@ -139,7 +145,7 @@ impl DPMSolverMultistepScheduler {
}

/// One step for the second-order multistep DPM-Solver.
fn multistep_dpm_solver_second_order_update(
fn second_order_update(
&self,
model_output_list: &Vec<Tensor>,
timestep_list: [usize; 2],
Expand Down Expand Up @@ -192,7 +198,7 @@ impl DPMSolverMultistepScheduler {
}

/// One step for the third-order multistep DPM-Solver
fn multistep_dpm_solver_third_order_update(
fn third_order_update(
&self,
model_output_list: &Vec<Tensor>,
timestep_list: [usize; 3],
Expand Down Expand Up @@ -237,11 +243,11 @@ impl DPMSolverMultistepScheduler {
}
}

pub fn timesteps(&self) -> &[usize] {
fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}

pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L457
let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap();

Expand All @@ -266,24 +272,14 @@ impl DPMSolverMultistepScheduler {
|| self.lower_order_nums < 1
|| lower_order_final
{
self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
self.first_order_update(model_output, timestep, prev_timestep, sample)
} else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second {
let timestep_list = [self.timesteps[step_index - 1], timestep];
self.multistep_dpm_solver_second_order_update(
&self.model_outputs,
timestep_list,
prev_timestep,
sample,
)
self.second_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
} else {
let timestep_list =
[self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep];
self.multistep_dpm_solver_third_order_update(
&self.model_outputs,
timestep_list,
prev_timestep,
sample,
)
self.third_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
};

if self.lower_order_nums < self.config.solver_order {
Expand All @@ -293,12 +289,16 @@ impl DPMSolverMultistepScheduler {
prev_sample
}

pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned()
+ (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise
}

pub fn init_noise_sigma(&self) -> f64 {
fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}

fn scale_model_input(&self, _sample: Tensor, _timestep: usize) -> Tensor {
todo!()
}
}
34 changes: 18 additions & 16 deletions src/schedulers/dpmsolver_singlestep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::iter::repeat;

use super::{
betas_for_alpha_bar,
dpmsolver::{DPMSolverAlgorithmType, DPMSolverSchedulerConfig, DPMSolverType},
dpmsolver::{
DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType,
},
BetaSchedule, PredictionType,
};
use tch::{kind, Kind, Tensor};
Expand All @@ -23,8 +25,8 @@ pub struct DPMSolverSinglestepScheduler {
pub config: DPMSolverSchedulerConfig,
}

impl DPMSolverSinglestepScheduler {
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => Tensor::linspace(
config.beta_start.sqrt(),
Expand Down Expand Up @@ -141,9 +143,9 @@ impl DPMSolverSinglestepScheduler {
/// * `timestep` - current discrete timestep in the diffusion chain
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
/// * `sample` - current instance of sample being created by diffusion process
fn dpm_solver_first_order_update(
fn first_order_update(
&self,
model_output: &Tensor,
model_output: Tensor,
timestep: usize,
prev_timestep: usize,
sample: &Tensor,
Expand Down Expand Up @@ -171,7 +173,7 @@ impl DPMSolverSinglestepScheduler {
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
/// * `sample` - current instance of sample being created by diffusion process
fn singlestep_dpm_solver_second_order_update(
fn second_order_update(
&self,
model_output_list: &Vec<Tensor>,
timestep_list: [usize; 2],
Expand Down Expand Up @@ -232,7 +234,7 @@ impl DPMSolverSinglestepScheduler {
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
/// * `sample` - current instance of sample being created by diffusion process
fn singlestep_dpm_solver_third_order_update(
fn third_order_update(
&self,
model_output_list: &Vec<Tensor>,
timestep_list: [usize; 3],
Expand Down Expand Up @@ -290,13 +292,13 @@ impl DPMSolverSinglestepScheduler {
}
}

pub fn timesteps(&self) -> &[usize] {
fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}

/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
sample
}

Expand All @@ -307,7 +309,7 @@ impl DPMSolverSinglestepScheduler {
/// * `model_output` - direct output from learned diffusion model
/// * `timestep` - current discrete timestep in the diffusion chain
/// * `sample` - current instance of sample being created by diffusion process
pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535
let step_index: usize = self.timesteps.iter().position(|&t| t == timestep).unwrap();

Expand All @@ -329,19 +331,19 @@ impl DPMSolverSinglestepScheduler {
};

match order {
1 => self.dpm_solver_first_order_update(
&self.model_outputs[self.model_outputs.len() - 1],
1 => self.first_order_update(
model_output,
timestep,
prev_timestep,
&self.sample.as_ref().unwrap(),
),
2 => self.singlestep_dpm_solver_second_order_update(
2 => self.second_order_update(
&self.model_outputs,
[self.timesteps[step_index - 1], self.timesteps[step_index]],
prev_timestep,
&self.sample.as_ref().unwrap(),
),
3 => self.singlestep_dpm_solver_third_order_update(
3 => self.third_order_update(
&self.model_outputs,
[
self.timesteps[step_index - 2],
Expand All @@ -357,12 +359,12 @@ impl DPMSolverSinglestepScheduler {
}
}

pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned()
+ (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise
}

pub fn init_noise_sigma(&self) -> f64 {
fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}
Expand Down

0 comments on commit 83b28b3

Please sign in to comment.