I'm using DQN from Stable-Baselines3 (sb3) to train a model. I want to train two agents that play against each other alternately. The problem is that as soon as I call model.learn(total_timesteps=N), which is the central method for training the model, my environment changes: I use an action mask, and as soon as model.learn(total_timesteps=N) is called, this mask is reset to its initial state and only invalid moves occur. There is also no real learning progress.
The movement logic works and the environment is implemented correctly. Masking invalid actions with -np.inf also works, and the agent selects a valid action. After a few timesteps or episodes, however, the Q-values have to be updated so that the model learns, which is what the learn() call does. It takes an environment as input and executes reset()/step() internally. That is how I read the docs, and it could explain why my environment is always reset to the beginning after the learn() call and generally doesn't work properly. Does anyone know how to implement DQN with sb3 correctly, and how to make this model.learn(total_timesteps=N) call work in my setup?
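To illustrate what I mean by masking with -np.inf, here is a minimal numpy sketch of the idea. The real masked_predict reads the Q-values from the agent's network; mask_q_values and the example values below are made up for illustration:

```python
import numpy as np

def mask_q_values(q_values, action_mask):
    """Set Q-values of invalid actions to -inf so argmax only picks valid ones."""
    return np.where(action_mask.astype(bool), q_values, -np.inf)

# Example: action 1 has the highest Q-value but is marked invalid
q = np.array([0.5, 2.0, 1.0])
mask = np.array([1, 0, 1])  # 1 = valid, 0 = invalid
action = int(np.argmax(mask_q_values(q, mask)))
print(action)  # 2 -> the best *valid* action
```

This part behaves as expected in my loop; the problem only starts once learn() is called.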
Here is the relevant part of the code:
while not done and turn_count < MAX_TURNS_PER_EPISODE:
    current_player = info.get("current_player")
    action_mask = np.array(info.get("action_mask", np.ones(env.action_space.n)), dtype=np.int32)
    # Choose the right agent
    current_agent = roman_agent if current_player == "R" else german_agent
    # Get all valid moves as coordinates
    valid_moves = env.unwrapped.get_all_valid_moves(current_player)
    # Select action based on masking
    action = masked_predict(current_agent, obs, action_mask)
    # Perform a step in the environment
    obs, reward, done, truncated, info = env.step(action)
    obs = np.array(obs, dtype=np.float32)
    episode_rewards.append(reward)
    # Save detailed logs
    log_move(turn_count, current_player, action_mask, valid_moves)
    # If the episode was truncated
    if truncated:
        break
    turn_count += 1

print(f"Start training by episode {episode + 1}")
roman_agent.learn(total_timesteps=500)
german_agent.learn(total_timesteps=500)
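As far as I understand it, learn() collects its own rollouts and resets the environment when training starts, which would explain why the state from my outer loop disappears. A toy illustration of that understanding (not actual SB3 code; CountingEnv and toy_learn are made up for the example):

```python
# Toy stand-in for what I think happens: "learn" resets the env first,
# so any mask/state prepared in an outer loop is discarded.
class CountingEnv:
    def __init__(self):
        self.resets = 0
        self.t = 0
    def reset(self):
        self.resets += 1
        self.t = 0
        return 0, {}
    def step(self, action):
        self.t += 1
        return 0, 0.0, False, False, {}

def toy_learn(env, total_timesteps):
    obs, info = env.reset()  # <- here the outer loop's state is lost
    for _ in range(total_timesteps):
        env.step(0)

env = CountingEnv()
env.reset()         # outer loop sets things up
toy_learn(env, 10)  # "learn" resets again before training
print(env.resets)   # 2
```

If that reading is right, I don't see how to combine my own turn-based loop with two separate learn() calls.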
If requested, I can also provide the console logs.