Contents: three .zip archives with the trained model ensembles, plus the Python file defining the torch.nn.Module class:
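- 01Layers16Cells_Ensemble.zip (1 layer, 16 hidden cells)
- 02Layers32Cells_Ensemble.zip (2 layers, 32 hidden cells)
- 04Layers64Cells_Ensemble.zip (4 layers, 64 hidden cells)
- recurrent_neuralnetworks.py (defines LSTMModelWithTeacherForcing)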
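Loading a checkpoint manually (shown for the 2-layer, 32-cell run; set num_hidden and num_layers to match the archive the checkpoint comes from):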
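```python
import torch
from recurrent_neuralnetworks import LSTMModelWithTeacherForcing

# Hyperparameters must match the ensemble the checkpoint was trained with.
model = LSTMModelWithTeacherForcing(
    num_features=12,  # 11 inputs + 1 temperature feedback
    num_hidden=32,    # match the run (16 / 32 / 64)
    num_layers=2,     # match the run (1 / 2 / 4)
    num_labels=1,
)

# The .pth file holds a state_dict; weights_only=True (PyTorch >= 1.13)
# restricts unpickling to tensors and plain containers.
state = torch.load("path/to/model_best_....pth", map_location="cpu", weights_only=True)
model.load_state_dict(state)
model.eval()  # switch to inference mode (e.g. disables dropout)
```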
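Since the layer and cell counts are encoded in each archive name, they can be read off programmatically instead of typed by hand. The sketch below is a minimal illustration, not part of the released code: `run_config_from_name` and `list_checkpoints` are hypothetical helpers, and it assumes the checkpoints inside each archive use the `.pth` extension, as in the loading example.

```python
import re
import zipfile

# Hypothetical helper: archive names follow the pattern
# "<NN>Layers<MM>Cells_Ensemble.zip", e.g. "02Layers32Cells_Ensemble.zip".
_RUN_NAME = re.compile(r"(\d+)Layers(\d+)Cells")


def run_config_from_name(archive_name: str) -> dict:
    """Return the num_layers / num_hidden values encoded in an archive name."""
    match = _RUN_NAME.search(archive_name)
    if match is None:
        raise ValueError(f"unrecognised archive name: {archive_name!r}")
    num_layers, num_hidden = (int(g) for g in match.groups())
    return {"num_layers": num_layers, "num_hidden": num_hidden}


def list_checkpoints(archive) -> list:
    """List the .pth members of an ensemble archive (path or file-like object).

    Assumption: every ensemble member is stored as a .pth file in the archive.
    """
    with zipfile.ZipFile(archive) as zf:
        return sorted(name for name in zf.namelist() if name.endswith(".pth"))


if __name__ == "__main__":
    for name in ("01Layers16Cells_Ensemble.zip",
                 "02Layers32Cells_Ensemble.zip",
                 "04Layers64Cells_Ensemble.zip"):
        print(name, run_config_from_name(name))
```

The returned dict can be unpacked straight into the constructor, e.g. `LSTMModelWithTeacherForcing(num_features=12, num_labels=1, **run_config_from_name("02Layers32Cells_Ensemble.zip"))`, so the hyperparameters cannot drift out of sync with the archive being loaded.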