
python - How to prevent certain input from impacting certain output of neural networks in pytorch? - Stack Overflow


I have an LSTM model that receives 5 inputs to predict 3 outputs:

import torch
import torch.nn as nn

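class LstmModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x has shape (batch_size, seq_length, input_size)
        _, (hn, _) = self.lstm(x)
        # Map the last hidden state to all output_size predictions at once
        return self.fc(hn[-1])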

I want to prevent certain inputs from having any effect on certain outputs. For example, the first input should not affect the prediction of the second output. In other words, the second prediction must not be a function of the first input.
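The constraint can be phrased as a gradient condition: output k is not a function of input j when the gradient of output k with respect to feature j is zero at every time step. Below is a minimal sketch for spot-checking this with `torch.autograd.grad`. The names `output_depends_on_input`, `PlainLstm`, and `ZeroFirstInput` are hypothetical, and a zero gradient on one random batch is evidence rather than proof, whereas a nonzero gradient does prove a dependence.

```python
import torch
import torch.nn as nn


def output_depends_on_input(model, x, out_idx, in_idx):
    """Hypothetical helper: True if output `out_idx` has a nonzero gradient
    w.r.t. input feature `in_idx` anywhere in the batch.

    Assumes model(x) returns shape (batch, n_outputs) and x has shape
    (batch, seq_length, n_inputs).
    """
    x = x.detach().clone().requires_grad_(True)
    y = model(x)
    # Summing over the batch gives a scalar; its gradient w.r.t. x holds the
    # sensitivity of output `out_idx` to every time step and feature.
    (grad,) = torch.autograd.grad(y[:, out_idx].sum(), x, allow_unused=True)
    if grad is None:  # the output does not use x at all
        return False
    return bool(grad[:, :, in_idx].abs().max() > 0)


class PlainLstm(nn.Module):
    # Hypothetical unconstrained model: one LSTM shared by all outputs.
    def __init__(self, input_size=5, hidden_size=16, output_size=3):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        _, (hn, _) = self.lstm(x)
        return self.fc(hn[-1])


class ZeroFirstInput(PlainLstm):
    # Hypothetical variant that zeroes input feature 0 before the LSTM,
    # so no output can depend on it.
    def forward(self, x):
        mask = torch.ones(x.shape[-1], device=x.device)
        mask[0] = 0.0
        return super().forward(x * mask)


torch.manual_seed(0)
x = torch.randn(4, 7, 5)
print(output_depends_on_input(PlainLstm(), x, out_idx=1, in_idx=0))      # True
print(output_depends_on_input(ZeroFirstInput(), x, out_idx=1, in_idx=0))  # False
```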

One solution I have tried is to use a separate LSTM for each output:

class LstmModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Each LSTM sees every input except one, so it receives input_size - 1 features
        self.lstm1 = nn.LSTM(input_size - 1, hidden_size, batch_first=True)
        self.lstm2 = nn.LSTM(input_size - 1, hidden_size, batch_first=True)
        self.lstm3 = nn.LSTM(input_size - 1, hidden_size, batch_first=True)
        # Each head predicts a single output value
        self.fc1 = nn.Linear(hidden_size, 1)
        self.fc2 = nn.Linear(hidden_size, 1)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # x has shape (batch_size, seq_length, input_size)
        # Split into the 5 input features, each of shape (batch_size, seq_length, 1)
        input1, input2, input3, input4, input5 = x.split(1, dim=2)

        # Build a separate input for each output by leaving one feature out
        # Output 1 ignores input2
        x_for_output1 = torch.cat((input1, input3, input4, input5), dim=2)
        # Output 2 ignores input3
        x_for_output2 = torch.cat((input1, input2, input4, input5), dim=2)
        # Output 3 ignores input4
        x_for_output3 = torch.cat((input1, input2, input3, input5), dim=2)

        # Run each input through its own LSTM and use the last hidden state
        _, (hn1, _) = self.lstm1(x_for_output1)
        output1 = self.fc1(hn1[-1])

        _, (hn2, _) = self.lstm2(x_for_output2)
        output2 = self.fc2(hn2[-1])

        _, (hn3, _) = self.lstm3(x_for_output3)
        output3 = self.fc3(hn3[-1])

        return output1, output2, output3

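The hand-written `torch.cat` calls can also be expressed as zero-masking driven by a 0/1 matrix, which keeps the full input width for every LSTM. A zeroed feature contributes nothing to the input-to-hidden product, so an LSTM fed the masked copy cannot depend on it. This is a sketch under that assumption; `allowed` and `masked_inputs` are hypothetical names.

```python
import torch

# Hypothetical 0/1 dependency matrix: allowed[k, j] == 1 means output k may use input j.
# The rows mirror the exclusions in the code above:
# output 1 ignores input2, output 2 ignores input3, output 3 ignores input4.
allowed = torch.tensor([
    [1.0, 0.0, 1.0, 1.0, 1.0],
    [1.0, 1.0, 0.0, 1.0, 1.0],
    [1.0, 1.0, 1.0, 0.0, 1.0],
])


def masked_inputs(x, allowed):
    """Return one copy of x per output with the forbidden features set to zero.

    x has shape (batch_size, seq_length, input_size); every copy keeps that shape.
    A zeroed feature adds nothing to W_ih @ x_t, so the LSTM that receives
    the copy cannot depend on it.
    """
    allowed = allowed.to(device=x.device, dtype=x.dtype)
    return [x * row for row in allowed]  # each row broadcasts over (batch, seq)


x = torch.randn(2, 4, 5)
copies = masked_inputs(x, allowed)
print(len(copies), tuple(copies[0].shape))  # 3 (2, 4, 5)
```

Each copy can go to an `nn.LSTM(input_size, hidden_size, batch_first=True)`. On its own, this does not remove the three sequential LSTM calls.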
The problem with using separate LSTMs is that the model takes at least three times as long to run, because the LSTM is now run three times, once per output. Is there a more efficient way to enforce this constraint in a single pass?
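One way to get a single LSTM call while keeping the guarantee is to pack the three LSTMs into one `nn.LSTM` whose hidden size is three times larger and whose weights are masked. The input-to-hidden weights of block k are zeroed for the inputs output k must ignore, and the hidden-to-hidden weights are block-diagonal. Because the LSTM cell and hidden updates are elementwise, the blocks then evolve independently, just like three separate LSTMs. This is a minimal sketch, assuming a single-layer, unidirectional LSTM and one scalar per output. `BlockMaskedLstm`, `allowed`, `head_weight`, and `head_bias` are hypothetical names, not a PyTorch API.

```python
import torch
import torch.nn as nn


class BlockMaskedLstm(nn.Module):
    """Sketch: n_outputs independent LSTMs packed into one single-layer nn.LSTM.

    The hidden units are split into n_outputs blocks of hidden_size units.
    Block k sees only the inputs with allowed[k, j] == 1, and the recurrent
    weights are block-diagonal, so blocks never exchange information.
    """

    def __init__(self, allowed, hidden_size):
        super().__init__()
        allowed = torch.as_tensor(allowed, dtype=torch.float32)
        n_out, n_in = allowed.shape
        total = n_out * hidden_size
        self.n_out, self.hidden_size = n_out, hidden_size
        self.lstm = nn.LSTM(n_in, total, batch_first=True)

        # Block index of every hidden unit: [0, ..., 0, 1, ..., 1, 2, ..., 2]
        block = torch.arange(total) // hidden_size
        # nn.LSTM stacks the 4 gates (i, f, g, o) along dim 0 of each weight,
        # so the per-unit mask is repeated 4 times.
        ih_mask = allowed[block].repeat(4, 1)  # (4*total, n_in)
        hh_mask = (block[:, None] == block[None, :]).float().repeat(4, 1)  # (4*total, total)
        self.register_buffer("ih_mask", ih_mask)
        self.register_buffer("hh_mask", hh_mask)

        with torch.no_grad():
            self.lstm.weight_ih_l0.mul_(self.ih_mask)
            self.lstm.weight_hh_l0.mul_(self.hh_mask)
        # Mask the gradients too, so training never revives the zeroed weights.
        self.lstm.weight_ih_l0.register_hook(lambda g: g * self.ih_mask)
        self.lstm.weight_hh_l0.register_hook(lambda g: g * self.hh_mask)

        # One small linear head per block (block k -> output k).
        bound = hidden_size ** -0.5
        self.head_weight = nn.Parameter(torch.empty(n_out, hidden_size).uniform_(-bound, bound))
        self.head_bias = nn.Parameter(torch.empty(n_out).uniform_(-bound, bound))

    def forward(self, x):
        # x: (batch_size, seq_length, n_inputs); one LSTM call for all outputs
        _, (hn, _) = self.lstm(x)
        h = hn[-1].reshape(-1, self.n_out, self.hidden_size)  # (batch, n_out, hidden_size)
        return (h * self.head_weight).sum(dim=-1) + self.head_bias  # (batch, n_out)


# The question's example: the second output (index 1) must ignore the first input (index 0).
allowed = torch.ones(3, 5)
allowed[1, 0] = 0.0
model = BlockMaskedLstm(allowed, hidden_size=16)
y = model(torch.randn(8, 10, 5))
print(y.shape)  # torch.Size([8, 3])
```

The saving comes from replacing three sequential LSTM calls with one. The recurrent matrix multiply is larger than strictly needed, since two thirds of `weight_hh_l0` is zero for three outputs, so whether this beats the three-LSTM version depends on the hardware and sizes and should be measured. The gradient hooks keep the masked weights at zero for optimizers such as SGD or Adam, whose updates stay zero when the gradient is always zero. If the weights are modified outside backpropagation, multiplying them by `ih_mask` and `hh_mask` again restores the constraint.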
