Source code for pypots.nn.modules.timellm.backbone

""" """

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import os

import torch
import torch.nn as nn
from transformers import (
    LlamaModel,
    LlamaTokenizer,
    GPT2Model,
    GPT2Tokenizer,
    BertModel,
    BertTokenizer,
)

from .layers import ReprogrammingLayer
from ..patchtst.layers import PatchEmbedding, FlattenHead
from ..revin import RevIN

# Names of the pretrained language models that BackboneTimeLLM can load as its frozen backbone.
SUPPORTED_LLM = ["LLaMA", "GPT2", "BERT"]

# Task names accepted by BackboneTimeLLM's argument check.
# NOTE(review): BackboneTimeLLM only builds an output head for the two forecasting tasks and
# "imputation"; "classification" and "clustering" pass the check but then hit NotImplementedError.
SUPPORTED_TASKS = [
    "long_term_forecast",
    "short_term_forecast",
    "imputation",
    "classification",
    "clustering",
]


class BackboneTimeLLM(nn.Module):
    """Time-LLM backbone: feeds reprogrammed time-series patches through a frozen pretrained LLM.

    The input is normalized by a RevIN layer and cut into patches. The patch embeddings are
    reprogrammed against ``n_tokens`` (1000) source embeddings, obtained by a linear map over the
    LLM's vocabulary embeddings. They are then prefixed with the embedded text prompt (dataset
    description, task, simple input statistics) and passed through the frozen LLM. The first
    ``d_ffn`` hidden dimensions of the last ``n_patches`` positions go through a flatten head and
    are de-normalized by the same RevIN layer.

    Parameters
    ----------
    n_steps :
        Number of input time steps; must be greater than ``patch_size``. Quoted in the forecasting prompt.
    n_features :
        Number of variables in the input series.
    n_pred_steps :
        Output length of the flatten head; also quoted in the forecasting prompt.
    n_layers :
        Passed as ``num_hidden_layers`` when loading the pretrained LLM.
    patch_size :
        Patch length given to the patch embedding.
    patch_stride :
        Patch stride; passed twice to the patch embedding (the second presumably as padding — TODO confirm).
    d_model :
        Patch-embedding dimension fed to the reprogramming layer.
    d_ffn :
        Number of leading LLM hidden dimensions kept for the output head.
    d_llm :
        Passed to the reprogramming layer; presumably must equal the LLM hidden size — TODO confirm.
    n_heads :
        Number of heads of the reprogramming layer.
    llm_model_type :
        One of ``SUPPORTED_LLM`` ("LLaMA", "GPT2", "BERT").
    dropout :
        Dropout passed to the patch embedding and used as ``head_dropout`` of the output head.
    domain_prompt_content :
        Dataset description inserted at the beginning of every prompt.
    task_name :
        One of ``SUPPORTED_TASKS``; only the two forecasting tasks and "imputation" are implemented,
        the others raise ``NotImplementedError``.
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_pred_steps: int,
        n_layers: int,
        patch_size: int,
        patch_stride: int,
        d_model: int,
        d_ffn: int,
        d_llm: int,
        n_heads: int,
        llm_model_type: str,
        dropout: float,
        domain_prompt_content: str,
        task_name: str,
    ):
        super().__init__()
        self.n_features = n_features
        self.n_pred_steps = n_pred_steps
        self.n_steps = n_steps
        self.d_ffn = d_ffn
        self.d_llm = d_llm
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.task_name = task_name
        self.top_k = 5  # fixed value, the same as the official implementation

        assert n_steps > patch_size, "The length of the time series must be greater than the patch length."
        assert llm_model_type in SUPPORTED_LLM, f"The LLM model type must be one of {SUPPORTED_LLM}."
        assert task_name in SUPPORTED_TASKS, f"The task name must be one of {SUPPORTED_TASKS}."

        # Load the pretrained LLM and its tokenizer; num_hidden_layers overrides the layer count
        # of the pretrained config, and hidden states/attentions are requested in the outputs.
        if llm_model_type == "LLaMA":
            self.llm_model = LlamaModel.from_pretrained(
                "huggyllama/llama-7b",
                num_hidden_layers=n_layers,
                output_attentions=True,
                output_hidden_states=True,
                # load_in_4bit=True
            )
            self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
        elif llm_model_type == "GPT2":
            self.llm_model = GPT2Model.from_pretrained(
                "openai-community/gpt2",
                num_hidden_layers=n_layers,
                output_attentions=True,
                output_hidden_states=True,
            )
            self.tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        elif llm_model_type == "BERT":
            self.llm_model = BertModel.from_pretrained(
                "google-bert/bert-base-uncased",
                num_hidden_layers=n_layers,
                output_attentions=True,
                output_hidden_states=True,
            )
            self.tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
        else:
            raise Exception("LLM model is not defined")

        # forward() tokenizes prompts with padding=True, so the tokenizer needs a pad token:
        # reuse EOS when the tokenizer has one, otherwise register "[PAD]".
        if self.tokenizer.eos_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            pad_token = "[PAD]"
            self.tokenizer.add_special_tokens({"pad_token": pad_token})
            self.tokenizer.pad_token = pad_token

        # freeze the LLM model
        for param in self.llm_model.parameters():
            param.requires_grad = False

        self.patch_embedding = PatchEmbedding(
            d_model,
            patch_size,
            patch_stride,
            patch_stride,
            dropout,
            False,
        )
        self.domain_prompt_content = domain_prompt_content

        # The LLM's (frozen) vocabulary embeddings, linearly mapped to n_tokens source embeddings
        # that the reprogramming layer attends over.
        self.word_embeddings = self.llm_model.get_input_embeddings().weight
        self.vocab_size = self.word_embeddings.shape[0]
        self.n_tokens = 1000
        self.mapping_layer = nn.Linear(self.vocab_size, self.n_tokens)
        self.reprogramming_layer = ReprogrammingLayer(d_model, n_heads, self.d_ffn, self.d_llm)

        # Number of patch positions taken from the end of the LLM output; the "+ 2" presumably
        # accounts for the end padding of the patch embedding — TODO confirm against PatchEmbedding.
        self.n_patches = int((n_steps - self.patch_size) / self.patch_stride + 2)
        self.revin_layer = RevIN(n_features, affine=False)

        if self.task_name in ["long_term_forecast", "short_term_forecast", "imputation"]:
            self.output_projection = FlattenHead(
                d_ffn * self.n_patches,
                n_pred_steps,
                n_features,
                head_dropout=dropout,
            )
        else:
            raise NotImplementedError

    def calc_lags(self, x_enc):
        """Return the lags with the largest circular autocorrelation.

        The autocorrelation is computed along dim 1 of ``x_enc`` and averaged over dim 2.

        Parameters
        ----------
        x_enc :
            3-D tensor; forward() passes shape (n_series, n_steps, 1).

        Returns
        -------
        Integer tensor of shape (dim 0 of ``x_enc``, min(top_k, n_steps)) with lag indices ordered
        by decreasing averaged autocorrelation.
        """
        series = x_enc.permute(0, 2, 1).contiguous()  # time axis last
        n_steps = series.shape[-1]
        spectrum = torch.fft.rfft(series, dim=-1)
        # Inverse FFT of X * conj(X) is the circular autocorrelation. ``n`` must be explicit:
        # by default irfft assumes an even length, which is wrong for odd-length series.
        corr = torch.fft.irfft(spectrum * torch.conj(spectrum), n=n_steps, dim=-1)
        mean_corr = torch.mean(corr, dim=1)
        # topk raises if k exceeds the number of lags, so clamp for very short series.
        k = min(self.top_k, mean_corr.shape[-1])
        _, lags = torch.topk(mean_corr, k, dim=-1)
        return lags

    @staticmethod
    def _amp_enabled():
        """Return whether the ``ENABLE_AMP`` environment variable asks for bfloat16 patch inputs.

        Unset, empty, "0", "false", "no" and "off" (case-insensitive) mean disabled; any other
        value enables it. ``os.getenv`` returns a string, so plain truthiness would treat "0" as on.
        """
        value = os.getenv("ENABLE_AMP", "")
        return value.strip().lower() not in ("", "0", "false", "no", "off")

    def _build_prompts(self, x_enc, missing_mask=None):
        """Build one text prompt per series of ``x_enc``.

        Parameters
        ----------
        x_enc :
            Normalized series, shape (n_series, n_steps, 1) as produced in forward().
        missing_mask :
            Tensor with the same shape as ``x_enc``; required for the "imputation" task because
            each series' mask is written into its prompt.

        Returns
        -------
        list of str
            One prompt per row of ``x_enc``.

        Raises
        ------
        ValueError
            If the task is "imputation" and ``missing_mask`` is None.
        NotImplementedError
            If the task has no prompt template.
        """
        forecasting = self.task_name in ("long_term_forecast", "short_term_forecast")
        if not forecasting and self.task_name != "imputation":
            raise NotImplementedError
        if self.task_name == "imputation" and missing_mask is None:
            raise ValueError("missing_mask is required for the imputation task.")

        # Move the statistics to Python once, instead of one device sync per series and statistic.
        min_values = torch.min(x_enc, dim=1)[0][:, 0].tolist()
        max_values = torch.max(x_enc, dim=1)[0][:, 0].tolist()
        medians = torch.median(x_enc, dim=1).values[:, 0].tolist()
        trends = x_enc.diff(dim=1).sum(dim=1)[:, 0].tolist()
        lags = self.calc_lags(x_enc)
        n_lags = lags.shape[-1]
        lags = lags.tolist()

        prompts = []
        for i in range(x_enc.shape[0]):
            if forecasting:
                task_description = (
                    f"forecast the next {self.n_pred_steps} steps given "
                    f"the previous {self.n_steps} steps information; "
                )
            else:
                task_description = (
                    "given the observed information, "
                    f"impute the missing values that indicated as 0 in {missing_mask[i].flatten()}; "
                )
            prompts.append(
                f"<|start_prompt|>Dataset description: {self.domain_prompt_content}"
                "Task description: "
                f"{task_description}"
                "Input statistics: "
                f"min value {min_values[i]}, "
                f"max value {max_values[i]}, "
                f"median value {medians[i]}, "
                f"the trend of input is {'upward' if trends[i] > 0 else 'downward'}, "
                f"top {n_lags} lags are : {lags[i]}<|<end_prompt>|>"
            )
        return prompts

    def forward(self, x_enc, missing_mask=None):
        """Run the Time-LLM backbone.

        Parameters
        ----------
        x_enc :
            Input tensor of shape (batch, n_steps, n_features).
        missing_mask :
            Optional tensor with the same shape as ``x_enc``; required for the "imputation" task.

        Returns
        -------
        The de-normalized flatten-head output, presumably of shape (batch, n_pred_steps,
        n_features) — TODO confirm against FlattenHead.
        """
        x_enc = self.revin_layer(x_enc, mode="norm")
        B, T, N = x_enc.size()

        # Every variable is described by its own prompt: reshape to (B * N, T, 1) univariate series.
        series = x_enc.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
        if missing_mask is not None:
            missing_mask = missing_mask.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
        prompts = self._build_prompts(series, missing_mask)

        prompt_ids = self.tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True, max_length=2048
        ).input_ids
        prompt_embeddings = self.llm_model.get_input_embeddings()(prompt_ids.to(x_enc.device))
        source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)

        channel_first = x_enc.permute(0, 2, 1).contiguous()  # (B, N, T)
        if self._amp_enabled():
            channel_first = channel_first.to(torch.bfloat16)
        enc_out = self.patch_embedding(channel_first)
        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)

        # Prompt tokens first, reprogrammed patches last; the LLM output is sliced accordingly below.
        llm_inputs = torch.cat([prompt_embeddings, enc_out], dim=1)
        dec_out = self.llm_model(inputs_embeds=llm_inputs).last_hidden_state
        dec_out = dec_out[:, :, : self.d_ffn]
        dec_out = torch.reshape(dec_out, (-1, self.n_features, dec_out.shape[-2], dec_out.shape[-1]))
        dec_out = dec_out.permute(0, 1, 3, 2).contiguous()

        if self.task_name in ["long_term_forecast", "short_term_forecast", "imputation"]:
            dec_out = self.output_projection(dec_out[:, :, :, -self.n_patches :])
        else:
            raise NotImplementedError

        dec_out = dec_out.permute(0, 2, 1).contiguous()
        dec_out = self.revin_layer(dec_out, mode="denorm")
        return dec_out