Source code for pypots.nn.modules.csdi.backbone

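"""Backbone of CSDI: the diffusion noise schedule, the training/validation loss, and the reverse-process sampler."""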

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import numpy as np
import torch
import torch.nn as nn

from .layers import CsdiDiffusionModel


class BackboneCSDI(nn.Module):
    """Diffusion backbone of CSDI: noise schedule, denoising network, loss and sampler.

    The module owns a ``CsdiDiffusionModel`` noise predictor and the ``beta``
    schedule, and exposes the noise-prediction loss (``calc_loss`` /
    ``calc_loss_valid``) plus the reverse-diffusion sampler (``forward``).

    Naming note: in this class ``alpha_hat`` is the per-step value ``1 - beta``
    and ``alpha`` is its cumulative product over steps.

    Parameters
    ----------
    n_layers, n_heads, n_channels, d_diffusion_embedding :
        Passed through to ``CsdiDiffusionModel``.
    d_target :
        Stored on the instance; not otherwise used by this class.
    d_time_embedding, d_feature_embedding :
        Their sum is the side-information width given to the diffusion model
        (plus one extra channel for the conditional mask when the model is
        conditional).
    is_unconditional :
        If true, the diffusion model gets a single input channel; otherwise it
        gets two (conditional observations and the noisy target).
    n_diffusion_steps :
        Number of diffusion steps, i.e. the length of the ``beta`` schedule.
    schedule :
        ``"quad"`` (linear in ``sqrt(beta)``, then squared) or ``"linear"``.
    beta_start, beta_end :
        End points of the ``beta`` schedule.

    Raises
    ------
    ValueError
        If ``schedule`` is neither ``"quad"`` nor ``"linear"``.
    """

    def __init__(
        self,
        n_layers,
        n_heads,
        n_channels,
        d_target,
        d_time_embedding,
        d_feature_embedding,
        d_diffusion_embedding,
        is_unconditional,
        n_diffusion_steps,
        schedule,
        beta_start,
        beta_end,
    ):
        super().__init__()

        self.d_target = d_target
        self.d_time_embedding = d_time_embedding
        self.d_feature_embedding = d_feature_embedding
        self.is_unconditional = is_unconditional
        self.n_channels = n_channels
        self.n_diffusion_steps = n_diffusion_steps

        d_side = d_time_embedding + d_feature_embedding
        if self.is_unconditional:
            d_input = 1
        else:
            d_side += 1  # for conditional mask
            d_input = 2

        self.diff_model = CsdiDiffusionModel(
            n_diffusion_steps,
            d_diffusion_embedding,
            d_input,
            d_side,
            n_channels,
            n_heads,
            n_layers,
        )

        # parameters for diffusion models
        if schedule == "quad":
            self.beta = np.linspace(beta_start**0.5, beta_end**0.5, self.n_diffusion_steps) ** 2
        elif schedule == "linear":
            self.beta = np.linspace(beta_start, beta_end, self.n_diffusion_steps)
        else:
            raise ValueError(f"The argument schedule should be 'quad' or 'linear', but got {schedule}")

        # NumPy copies are kept for the scalar coefficients used while sampling;
        # the cumulative product is also registered as a buffer so that it
        # follows the module across devices and can be indexed by a step tensor.
        self.alpha_hat = 1 - self.beta
        self.alpha = np.cumprod(self.alpha_hat)
        self.register_buffer("alpha_torch", torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1))

    def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask) -> torch.Tensor:
        """Assemble the channel-stacked input of the diffusion model.

        Unconditional models receive only the noisy data, shape (B,1,K,L).
        Conditional models receive two channels, shape (B,2,K,L): the observed
        values where ``cond_mask`` is set, and the noisy values elsewhere.
        """
        if self.is_unconditional:
            return noisy_data.unsqueeze(1)  # (B,1,K,L)

        cond_obs = (cond_mask * observed_data).unsqueeze(1)
        noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1)
        return torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)

    def calc_loss_valid(self, observed_data, cond_mask, indicating_mask, side_info):
        """Return the loss averaged over every diffusion step ``0..n_diffusion_steps-1``.

        Each per-step loss is detached before accumulation, so the result
        carries no autograd graph.
        """
        loss_sum = 0
        for t in range(self.n_diffusion_steps):  # calculate loss for all t
            loss = self.calc_loss(observed_data, cond_mask, indicating_mask, side_info, set_t=t)
            loss_sum += loss.detach()
        return loss_sum / self.n_diffusion_steps

    def calc_loss(self, observed_data, cond_mask, indicating_mask, side_info, set_t=-1):
        """Noise-prediction loss on the positions selected by ``indicating_mask``.

        In training mode (``self.training``) one diffusion step is drawn
        uniformly at random per batch element. Otherwise every element uses the
        fixed step ``set_t``; with the default ``-1`` the schedule lookup uses
        negative indexing (the last step), so pass ``set_t`` explicitly when
        evaluating.

        Parameters
        ----------
        observed_data :
            Tensor of shape (B, K, L).
        cond_mask :
            Mask of positions given to the model as conditions.
        indicating_mask :
            Mask of positions on which the loss is computed.
        side_info :
            Side information forwarded unchanged to the diffusion model.
        set_t :
            Diffusion step used when the module is not in training mode.

        Returns
        -------
        Scalar tensor: the squared error between the injected noise and the
        predicted noise, summed over masked positions and divided by the mask
        sum (or by 1 when the mask is empty).
        """
        B, K, L = observed_data.shape
        device = observed_data.device

        if not self.training:  # for validation
            t = (torch.ones(B) * set_t).long().to(device)
        else:
            t = torch.randint(0, self.n_diffusion_steps, [B]).to(device)

        current_alpha = self.alpha_torch[t]  # (B,1,1)
        noise = torch.randn_like(observed_data)
        noisy_data = (current_alpha**0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise

        total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)
        predicted = self.diff_model(total_input, side_info, t)  # (B,K,L)

        target_mask = indicating_mask
        residual = (noise - predicted) * target_mask
        num_eval = target_mask.sum()
        # guard against an empty target mask to avoid dividing by zero
        loss = (residual**2).sum() / (num_eval if num_eval > 0 else 1)
        return loss

    def forward(self, observed_data, cond_mask, side_info, n_sampling_times) -> torch.Tensor:
        """Draw ``n_sampling_times`` samples by running the reverse diffusion process.

        Parameters
        ----------
        observed_data :
            Tensor of shape (B, K, L).
        cond_mask :
            Mask of positions treated as observed conditions.
        side_info :
            Side information forwarded unchanged to the diffusion model.
        n_sampling_times :
            Number of independent samples to generate per batch element.

        Returns
        -------
        Tensor of shape (B, n_sampling_times, K, L). It carries no autograd
        graph: sampling runs under ``torch.no_grad()``.
        """
        B, K, L = observed_data.shape
        device = observed_data.device

        # Only detached samples are ever returned, so tracking gradients through
        # every denoising step would just waste memory and time.
        with torch.no_grad():
            imputed_samples = torch.zeros(B, n_sampling_times, K, L, device=device)

            if not self.is_unconditional:
                # identical for every step and every sample: compute it once
                cond_obs = (cond_mask * observed_data).unsqueeze(1)

            for i in range(n_sampling_times):
                # generate noisy observation for unconditional model
                if self.is_unconditional:
                    noisy_obs = observed_data
                    noisy_cond_history = []
                    for t in range(self.n_diffusion_steps):
                        noise = torch.randn_like(noisy_obs)
                        # float(): keep the arithmetic in the tensor's dtype
                        # instead of mixing in NumPy float64 scalars
                        keep = float(self.alpha_hat[t] ** 0.5)
                        spread = float(self.beta[t] ** 0.5)
                        noisy_obs = keep * noisy_obs + spread * noise
                        noisy_cond_history.append(noisy_obs * cond_mask)

                current_sample = torch.randn_like(observed_data)

                for t in range(self.n_diffusion_steps - 1, -1, -1):
                    if self.is_unconditional:
                        diff_input = cond_mask * noisy_cond_history[t] + (1.0 - cond_mask) * current_sample
                        diff_input = diff_input.unsqueeze(1)  # (B,1,K,L)
                    else:
                        noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1)
                        diff_input = torch.cat([cond_obs, noisy_target], dim=1)  # (B,2,K,L)

                    step = torch.tensor([t], device=device)
                    predicted = self.diff_model(diff_input, side_info, step)

                    coeff1 = float(1 / self.alpha_hat[t] ** 0.5)
                    coeff2 = float((1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5)
                    current_sample = coeff1 * (current_sample - coeff2 * predicted)

                    # no noise is added on the final step (t == 0)
                    if t > 0:
                        noise = torch.randn_like(current_sample)
                        sigma = float(
                            ((1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]) ** 0.5
                        )
                        current_sample = current_sample + sigma * noise

                imputed_samples[:, i] = current_sample

        return imputed_samples