I have an LSTM model that receives 5 inputs to predict 3 outputs:
import torch
import torch.nn as nn
class LstmModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(CustomLSTMModel, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
None
I want to prevent certain input from having any impact on a certain output. Let's say, the first input should not have any effect on the prediction of the second output. In other words, the second prediction should not be a function of the first input.
One solution I have tried is using separate LSTMs for each output:
class LstmModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(CustomLSTMModel, self).__init__()
self.lstm1 = nn.LSTM(input_size, hidden_size, batch_first=True)
self.lstm2 = nn.LSTM(input_size, hidden_size, batch_first=True)
self.lstm3 = nn.LSTM(input_size, hidden_size, batch_first=True)
self.fc1 = nn.Linear(hidden_size, output_size)
self.fc2 = nn.Linear(hidden_size, output_size)
self.fc3 = nn.Linear(hidden_size, output_size)
def forward(self, x):
# Assume x is of shape (batch_size, seq_length, input_size)
# Split inputs
input1, input2, input3, input4, input5 = x.split(1, dim=2)
# Mask inputs for each output
# For output1, exclude input2
input1_for_output1 = torch.cat((input1, input3, input4, input5), dim=2)
# For output2, exclude input3
input2_for_output2 = torch.cat((input1, input2, input4, input5), dim=2)
# For output3, exclude input4
input3_for_output3 = torch.cat((input1, input2, input3, input5), dim=2)
# Process through LSTM
_, (hn1, _) = self.lstm1(input1_for_output1)
output1 = self.fc1(hn1[-1])
_, (hn2, _) = self.lstm2(input2_for_output2)
output2 = self.fc2(hn2[-1])
_, (hn3, _) = self.lstm3(input3_for_output3)
output3 = self.fc2(hn3[-1])
return output1, output2, output3
The problem with this approach is that it takes at least 3 times longer to run the model (since I am running LSTM 3 times, 1 for each output). Is it possible to do what I want to achieve more efficiently, with one run?