I'm trying to build a super simple grid trading strategy using MassTransit Saga State Machine, and I'm running into a challenge when it comes to tracking the state for each individual grid level.
The idea of the grid strategy is as follows:
- Set upper and lower limits.
- Split range into a grid (either arithmetic or geometric spacing).
- Ignore the closest level to the current price.
- Place buy orders below and sell orders above the current price.
- When a sell order executes, place a new buy order one level lower.
- When a buy order executes, place a new sell order one level higher.
- Repeat indefinitely.
Problem
Currently, I'm modeling the entire grid strategy as a single state machine with just one `CurrentState` for the entire grid (such as `Initialized`, `CrossingUp`, or `CrossingDown`). However, what I really want is to have individual states for each grid level. Specifically:
- Each grid level should have its own state: `Inactive`, `Buy`, or `Sell`.
- If a grid level is the closest one to the current price, its state should be `Inactive`.
- If a grid level is below the current price, its state should be `Buy`.
- If a grid level is above the current price, its state should be `Sell`.
Strategy Input Parameters
- Lower Limit: The lower boundary of the grid.
- Upper Limit: The upper boundary of the grid.
- Grid Count: The number of grid intervals; the grid has Grid Count + 1 price levels, including both limits.
The grid can be generated using either arithmetic or geometric spacing. Below is a Python version of how the grid levels are calculated:
def get_grids(lower_limit, upper_limit, grid_count, tp="arth"):
    """Return the grid price levels between lower_limit and upper_limit.

    Args:
        lower_limit: Lowest grid price (included in the result).
        upper_limit: Highest grid price (included in the result).
        grid_count: Number of grid intervals; the result has grid_count + 1 levels.
        tp: "arth" for evenly spaced levels (np.linspace) or "geom" for
            geometrically spaced levels (np.geomspace).

    Returns:
        A numpy array of grid_count + 1 levels from lower_limit to upper_limit.

    Raises:
        ValueError: if tp is neither "arth" nor "geom".
    """
    if tp == "arth":
        return np.linspace(lower_limit, upper_limit, grid_count + 1)
    if tp == "geom":
        return np.geomspace(lower_limit, upper_limit, grid_count + 1)
    # Fail loudly instead of printing and then hitting an unbound local on return.
    raise ValueError(f"not right range type: {tp!r} (expected 'arth' or 'geom')")
In C#, the arithmetic grid levels can be calculated as follows:
// Arithmetic spacing: gridCount equal steps starting at lowerLimit, i.e. gridCount + 1 levels
// with both limits included.
var step = (upperLimit - lowerLimit) / gridCount;
var gridLevels = new List<decimal>();
for (var levelIndex = 0; levelIndex <= gridCount; levelIndex++)
{
    gridLevels.Add(lowerLimit + step * levelIndex);
}
What I've tried (Minimal reproducible example)
// Version 1: a single GridInitialized event describes the whole grid.
var builder = Host.CreateApplicationBuilder(args);
builder.AddEventBus();

// Dispose the host on exit so the bus connection and hosted services are released.
using var host = builder.Build();
await host.StartAsync();

var bus = host.Services.GetRequiredService<IBus>();

// Strategy input parameters.
var lowerLimit = 25_000m;
var upperLimit = 35_000m;
var gridCount = 20;

// Arithmetic spacing: gridCount intervals => gridCount + 1 levels, both limits included.
var step = (upperLimit - lowerLimit) / gridCount;
List<decimal> gridLevels = [];
for (var i = 0; i <= gridCount; i++)
{
    gridLevels.Add(lowerLimit + step * i);
}

await bus.Publish<GridInitialized>(new
{
    CorrelationId = Guid.NewGuid(),
    Price = 27000m,
    GridLevels = gridLevels
});

Console.ReadLine();
await host.StopAsync();
/// <summary>
/// Saga data for the single-machine design: one instance holds the state of the whole grid.
/// </summary>
public class GridState : SagaStateMachineInstance
{
    /// <summary>Identifies the grid instance; matched against the messages' CorrelationId.</summary>
    public Guid CorrelationId { get; set; }

    /// <summary>Name of the saga's current state.</summary>
    public string CurrentState { get; set; } = null!;

    /// <summary>Price field for the grid (not assigned anywhere in the visible code — confirm intended use).</summary>
    public decimal CurrentPrice { get; set; }

    /// <summary>Grid price levels; starts as an empty list.</summary>
    public List<decimal> GridLevels { get; set; } = new List<decimal>();

    /// <summary>Timestamp field (not assigned anywhere in the visible code).</summary>
    public DateTime LastUpdate { get; set; }
}
/// <summary>
/// First attempt: one saga instance models the whole grid, so a single CurrentState
/// is shared by every grid level instead of each level having its own state.
/// </summary>
public class GridStateMachine : MassTransitStateMachine<GridState>
{
public GridStateMachine()
{
// The state name is stored in GridState.CurrentState.
InstanceState(x => x.CurrentState);
// All three events correlate on CorrelationId, i.e. one saga instance per grid (not per level).
Event(() => Initialized, e => e.CorrelateById(m => m.Message.CorrelationId));
Event(() => PriceCrossedUp, e => e.CorrelateById(m => m.Message.CorrelationId));
Event(() => PriceCrossedDown, e => e.CorrelateById(m => m.Message.CorrelationId));
// Only the Initial state is configured: PriceCrossedUp/PriceCrossedDown are declared above but
// no During(...) block handles them.
// NOTE(review): MassTransit's default for an event with no handler in the current state is to
// fault the message — confirm, and add During(Inactive/Buy/Sell, ...) handlers as needed.
// The message's Price and GridLevels are not copied onto the saga here, so GridState.CurrentPrice,
// GridState.GridLevels and GridState.LastUpdate are never assigned by this machine.
Initially(
When(Initialized)
.IfElse(context => IsClosestLevel(context.Message.Price),
then => then.TransitionTo(Inactive),
orElse => orElse.IfElse(context => IsLowerLevel(context.Message.Price),
then => then.TransitionTo(Buy).Then(context => PlaceBuyOrder(context.Message)),
orElse2 => orElse2.TransitionTo(Sell).Then(context => PlaceSellOrder(context.Message))
)
)
);
}
public State Inactive { get; private set; } = null!;
public State Buy { get; private set; } = null!;
public State Sell { get; private set; } = null!;
public Event<GridInitialized> Initialized { get; private set; } = null!;
public Event<PriceCrossedUp> PriceCrossedUp { get; private set; } = null!;
public Event<PriceCrossedDown> PriceCrossedDown { get; private set; } = null!;
// Placeholder: always returns false, so Initialized never takes the Inactive branch.
private bool IsClosestLevel(decimal price)
{
Console.WriteLine($"Is {price} the closest level?");
return false;
}
// Placeholder: always returns false; together with IsClosestLevel this means Initialized
// always ends in the Sell branch.
private bool IsLowerLevel(decimal price)
{
Console.WriteLine($"Is {price} lower level?");
return false;
}
// Placeholder; not referenced anywhere in this class.
private bool IsHigherLevel(decimal price)
{
Console.WriteLine($"Is {price} higher level?");
return false;
}
// Stub: only logs; no order is actually placed.
private void PlaceBuyOrder(GridInitialized request)
{
Console.WriteLine("Placing a buy order...");
}
// Stub: only logs; no order is actually placed.
private void PlaceSellOrder(GridInitialized request)
{
Console.WriteLine("Placing a sell order...");
}
}
/// <summary>Starts a grid saga with the start-up price and the computed grid levels.</summary>
public class GridInitialized
{
    /// <summary>Id of the grid saga instance.</summary>
    public Guid CorrelationId { get; set; }

    /// <summary>Price at which the grid is initialized.</summary>
    public decimal Price { get; set; }

    /// <summary>Grid price levels; each message starts with its own empty list.</summary>
    public List<decimal> GridLevels { get; set; } = new List<decimal>();
}

/// <summary>Event named for an upward price crossing; carries only the grid's correlation id.</summary>
public class PriceCrossedUp
{
    /// <summary>Id of the grid saga instance.</summary>
    public Guid CorrelationId { get; set; }
}

/// <summary>Event named for a downward price crossing; carries only the grid's correlation id.</summary>
public class PriceCrossedDown
{
    /// <summary>Id of the grid saga instance.</summary>
    public Guid CorrelationId { get; set; }
}
/// <summary>
/// Registration helpers that add MassTransit (RabbitMQ transport, in-memory saga repository)
/// to a host builder.
/// </summary>
public static class ServiceCollectionExtensions
{
    /// <summary>Registers MassTransit for the default <see cref="IBus"/>.</summary>
    /// <param name="builder">Host builder to configure.</param>
    /// <param name="massTransitConfiguration">Optional extra registrations (consumers, sagas, ...).</param>
    /// <returns>The same <paramref name="builder"/>, for chaining.</returns>
    public static IHostApplicationBuilder AddEventBus(
        this IHostApplicationBuilder builder,
        Action<IBusRegistrationConfigurator>? massTransitConfiguration = null) =>
        AddEventBus<IBus>(builder, massTransitConfiguration);

    /// <summary>
    /// Registers MassTransit for <typeparamref name="TBus"/>. It scans the entry assembly for saga
    /// state machines, sagas and activities, then applies <paramref name="massTransitConfiguration"/>,
    /// and finally configures the RabbitMQ transport.
    /// </summary>
    /// <remarks>
    /// RabbitMQ settings are read from configuration keys <c>RabbitMq:Host</c>, <c>RabbitMq:VirtualHost</c>,
    /// <c>RabbitMq:Username</c> and <c>RabbitMq:Password</c>. When a key is absent, the previous hard-coded
    /// value (localhost, "/", guest/guest) is used.
    /// </remarks>
    /// <param name="builder">Host builder to configure.</param>
    /// <param name="massTransitConfiguration">Optional extra registrations (consumers, sagas, ...).</param>
    /// <returns>The same <paramref name="builder"/>, for chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="builder"/> is null.</exception>
    public static IHostApplicationBuilder AddEventBus<TBus>(
        this IHostApplicationBuilder builder,
        Action<IBusRegistrationConfigurator>? massTransitConfiguration = null)
        where TBus : class, IBus
    {
        ArgumentNullException.ThrowIfNull(builder);

        // Keep credentials out of source while preserving the local-development defaults.
        var configuration = builder.Configuration;
        var rabbitHost = configuration["RabbitMq:Host"] ?? "localhost";
        var rabbitVirtualHost = configuration["RabbitMq:VirtualHost"] ?? "/";
        var rabbitUsername = configuration["RabbitMq:Username"] ?? "guest";
        var rabbitPassword = configuration["RabbitMq:Password"] ?? "guest";

        // GetEntryAssembly() can return null (e.g. when loaded from an unmanaged host);
        // fall back to this assembly instead of handing null to the scanners.
        var scanAssembly = Assembly.GetEntryAssembly() ?? typeof(ServiceCollectionExtensions).Assembly;

        builder.Services.AddMassTransit<TBus>(x =>
        {
            x.SetKebabCaseEndpointNameFormatter();
            x.SetInMemorySagaRepositoryProvider();

            x.AddSagaStateMachines(scanAssembly);
            x.AddSagas(scanAssembly);
            x.AddActivities(scanAssembly);

            massTransitConfiguration?.Invoke(x);

            x.UsingRabbitMq((context, cfg) =>
            {
                cfg.Host(rabbitHost, rabbitVirtualHost, h =>
                {
                    h.Username(rabbitUsername);
                    h.Password(rabbitPassword);
                });

                cfg.ConfigureEndpoints(context);
            });
        });

        return builder;
    }
}
I’m struggling to figure out how to manage individual states for each grid level in MassTransit. Since a grid strategy can have many levels, I need to have different states for each grid level instead of one global state for the entire grid.
How can I do that?
Edit
Here is my current progress:
// Program.cs
var builder = Host.CreateApplicationBuilder(args);
builder.Services.AddSingleton<IGridLevelTracker, GridLevelTracker>();
builder.AddEventBus(options =>
{
    options.AddConsumer<CurrentPriceUpdatedConsumer>();
});

// Dispose the host on exit so the bus connection and hosted services are released.
using var host = builder.Build();
await host.StartAsync();

var bus = host.Services.GetRequiredService<IBus>();
var gridLevelTracker = host.Services.GetRequiredService<IGridLevelTracker>();

// Strategy input parameters
var lowerLimit = 25_000m;
var upperLimit = 35_000m;
var gridCount = 20;

// Arithmetic progression: gridCount intervals => gridCount + 1 levels, both limits included.
var step = (upperLimit - lowerLimit) / gridCount;
for (var i = 0; i <= gridCount; i++)
{
    var price = lowerLimit + step * i;
    var gridLevelId = Guid.NewGuid();

    // Register the grid level so later price updates can be classified against it.
    gridLevelTracker.RegisterGridLevel(gridLevelId, price);

    // Create the per-level saga (GridLevelStateMachine moves it to Pending).
    await bus.Publish(new GridLevelInitialized
    {
        GridLevelId = gridLevelId,
        Price = price
    });
}

// NOTE(review): this price update is published immediately after the GridLevelInitialized
// messages. Its consumer publishes a GridLevelShouldBe* command to every level, and nothing here
// waits for the sagas to exist first, so a command could reach a level before its
// GridLevelInitialized has been processed — confirm how missing saga instances are handled.
var simulatedCurrentPrice = 29_123m;
Console.WriteLine($"\nUpdating current price to {simulatedCurrentPrice}\n");
await bus.Publish(new CurrentPriceUpdated
{
    CurrentPrice = simulatedCurrentPrice
});

Console.WriteLine("\nPress Enter to exit...");
Console.ReadLine();
await host.StopAsync();
// GridLevelState.cs
/// <summary>
/// Saga data for a single grid level; each level gets its own instance.
/// </summary>
public class GridLevelState : SagaStateMachineInstance
{
    /// <summary>
    /// The grid level's id (the <c>GridLevelId</c> carried by the grid-level messages), not an order id.
    /// </summary>
    public Guid CorrelationId { get; set; }

    /// <summary>Current state name, e.g. Pending, Inactive, Buy or Sell.</summary>
    public string CurrentState { get; set; } = null!;

    /// <summary>Price of this grid level.</summary>
    public decimal Price { get; set; }

    /// <summary>Id of the most recently placed order for this level, if any.</summary>
    public string? OrderId { get; set; }
}
/// <summary>
/// Per-level state machine. Each grid level has its own saga instance, correlated by
/// <c>GridLevelId</c>, which moves between Pending, Inactive, Buy and Sell and tracks the
/// id of the order it placed.
/// </summary>
/// <remarks>
/// Every price update sends a command to every level. A level therefore routinely receives a
/// command for the state it is already in; such repeats are ignored explicitly instead of being
/// left unhandled. Inactive also accepts Buy/Sell commands, so a level can be re-armed once the
/// price moves away from it.
/// </remarks>
public class GridLevelStateMachine : MassTransitStateMachine<GridLevelState>
{
    public GridLevelStateMachine()
    {
        Event(() => GridLevelInitialized, x => x.CorrelateById(m => m.Message.GridLevelId));
        Event(() => GridLevelShouldBeInactive, x => x.CorrelateById(m => m.Message.GridLevelId));
        Event(() => GridLevelShouldBeBuy, x => x.CorrelateById(m => m.Message.GridLevelId));
        Event(() => GridLevelShouldBeSell, x => x.CorrelateById(m => m.Message.GridLevelId));

        InstanceState(x => x.CurrentState);

        Initially(
            When(GridLevelInitialized)
                .Then(context =>
                {
                    context.Saga.Price = context.Message.Price;
                    Console.WriteLine($"Grid level {context.Saga.CorrelationId} initialized at price {context.Saga.Price}");
                })
                .TransitionTo(Pending));

        // A redelivered or duplicate initialization must not fault an already-created level.
        DuringAny(
            Ignore(GridLevelInitialized));

        During(Pending,
            When(GridLevelShouldBeInactive)
                .Then(context =>
                {
                    Console.WriteLine($"Grid level {context.Saga.CorrelationId} at price {context.Saga.Price} remains in Inactive state");
                })
                .TransitionTo(Inactive),
            When(GridLevelShouldBeBuy)
                .Then(context =>
                {
                    Console.WriteLine($"Grid level {context.Saga.CorrelationId} at price {context.Saga.Price} transitioning to Buy state");
                    context.Saga.OrderId = PlaceBuyOrder(context.Saga.Price);
                })
                .TransitionTo(Buy),
            When(GridLevelShouldBeSell)
                .Then(context =>
                {
                    Console.WriteLine($"Grid level {context.Saga.CorrelationId} at price {context.Saga.Price} transitioning to Sell state");
                    context.Saga.OrderId = PlaceSellOrder(context.Saga.Price);
                })
                .TransitionTo(Sell));

        // Without this block an Inactive level could never become Buy or Sell again.
        During(Inactive,
            Ignore(GridLevelShouldBeInactive),
            When(GridLevelShouldBeBuy)
                .Then(context =>
                {
                    Console.WriteLine($"Grid level {context.Saga.CorrelationId} at price {context.Saga.Price} transitioning to Buy state");
                    context.Saga.OrderId = PlaceBuyOrder(context.Saga.Price);
                })
                .TransitionTo(Buy),
            When(GridLevelShouldBeSell)
                .Then(context =>
                {
                    Console.WriteLine($"Grid level {context.Saga.CorrelationId} at price {context.Saga.Price} transitioning to Sell state");
                    context.Saga.OrderId = PlaceSellOrder(context.Saga.Price);
                })
                .TransitionTo(Sell));

        During(Buy,
            Ignore(GridLevelShouldBeBuy),
            When(GridLevelShouldBeInactive)
                .Then(context =>
                {
                    Console.WriteLine($"Grid level {context.Saga.CorrelationId} at price {context.Saga.Price} transitioning to Inactive state");
                    CancelBuyOrder(context.Saga);
                })
                .TransitionTo(Inactive),
            When(GridLevelShouldBeSell)
                .Then(context =>
                {
                    Console.WriteLine($"Grid level {context.Saga.CorrelationId} at price {context.Saga.Price} transitioning to Sell state");
                    CancelBuyOrder(context.Saga);
                    context.Saga.OrderId = PlaceSellOrder(context.Saga.Price);
                })
                .TransitionTo(Sell));

        During(Sell,
            Ignore(GridLevelShouldBeSell),
            When(GridLevelShouldBeInactive)
                .Then(context =>
                {
                    Console.WriteLine($"Grid level {context.Saga.CorrelationId} at price {context.Saga.Price} transitioning to Inactive state");
                    CancelSellOrder(context.Saga);
                })
                .TransitionTo(Inactive),
            When(GridLevelShouldBeBuy)
                .Then(context =>
                {
                    Console.WriteLine($"Grid level {context.Saga.CorrelationId} at price {context.Saga.Price} transitioning to Buy state");
                    CancelSellOrder(context.Saga);
                    context.Saga.OrderId = PlaceBuyOrder(context.Saga.Price);
                })
                .TransitionTo(Buy));

        SetCompletedWhenFinalized();
    }

    public State Pending { get; private set; } = null!;
    public State Inactive { get; private set; } = null!;
    public State Buy { get; private set; } = null!;
    public State Sell { get; private set; } = null!;

    public Event<GridLevelInitialized> GridLevelInitialized { get; private set; } = null!;
    public Event<GridLevelShouldBeInactive> GridLevelShouldBeInactive { get; private set; } = null!;
    public Event<GridLevelShouldBeBuy> GridLevelShouldBeBuy { get; private set; } = null!;
    public Event<GridLevelShouldBeSell> GridLevelShouldBeSell { get; private set; } = null!;

    // Stub: logs and returns a fresh id standing in for the exchange order id.
    private static string PlaceBuyOrder(decimal price)
    {
        Console.WriteLine($"Placing buy order at {price}");
        return Guid.NewGuid().ToString();
    }

    // Stub: logs and returns a fresh id standing in for the exchange order id.
    private static string PlaceSellOrder(decimal price)
    {
        Console.WriteLine($"Placing sell order at {price}");
        return Guid.NewGuid().ToString();
    }

    // Stub: cancels the level's tracked order by id, then forgets the id so it is never reused.
    private static void CancelBuyOrder(GridLevelState saga)
    {
        Console.WriteLine($"Cancelling buy order {saga.OrderId} at {saga.Price}");
        saga.OrderId = null;
    }

    // Stub: cancels the level's tracked order by id, then forgets the id so it is never reused.
    private static void CancelSellOrder(GridLevelState saga)
    {
        Console.WriteLine($"Cancelling sell order {saga.OrderId} at {saga.Price}");
        saga.OrderId = null;
    }
}
// Messages.cs
/// <summary>Creates the saga for one grid level.</summary>
public class GridLevelInitialized
{
    /// <summary>Id of the grid-level saga.</summary>
    public Guid GridLevelId { get; init; }

    /// <summary>Price of the grid level.</summary>
    public decimal Price { get; init; }
}

/// <summary>Commands a grid level into the Inactive state.</summary>
public class GridLevelShouldBeInactive
{
    /// <summary>Id of the grid-level saga.</summary>
    public Guid GridLevelId { get; init; }

    /// <summary>Price field (optional payload).</summary>
    public decimal Price { get; init; }
}

/// <summary>Commands a grid level into the Buy state.</summary>
public class GridLevelShouldBeBuy
{
    /// <summary>Id of the grid-level saga.</summary>
    public Guid GridLevelId { get; init; }

    /// <summary>Price field (optional payload).</summary>
    public decimal Price { get; init; }
}

/// <summary>Commands a grid level into the Sell state.</summary>
public class GridLevelShouldBeSell
{
    /// <summary>Id of the grid-level saga.</summary>
    public Guid GridLevelId { get; init; }

    /// <summary>Price field (optional payload).</summary>
    public decimal Price { get; init; }
}

/// <summary>Carries a new market price for the whole grid.</summary>
public class CurrentPriceUpdated
{
    /// <summary>The latest market price.</summary>
    public decimal CurrentPrice { get; init; }
}
// GridLevelTracker.cs
/// <summary>
/// Registry of grid levels (id → price) plus the most recent market price.
/// </summary>
public interface IGridLevelTracker
{
    /// <summary>Registers a grid level at the given price.</summary>
    void RegisterGridLevel(Guid gridLevelId, decimal price);

    /// <summary>Returns a snapshot of all registered grid level ids.</summary>
    List<Guid> GetAllGridLevelIds();

    /// <summary>Records the latest market price.</summary>
    void UpdateCurrentPrice(decimal currentPrice);

    /// <summary>
    /// Splits the registered levels into the one closest to the current price, the levels
    /// below it and the levels above it (both lists in ascending price order).
    /// </summary>
    (Guid? ClosestLevelId, List<Guid> LowerLevelIds, List<Guid> HigherLevelIds) GetLevelsRelativeToCurrentPrice();
}

/// <summary>
/// In-memory <see cref="IGridLevelTracker"/>. It is safe to share as a singleton: levels live in a
/// concurrent dictionary, and the current price is guarded by a lock.
/// </summary>
public class GridLevelTracker : IGridLevelTracker
{
    private readonly ConcurrentDictionary<Guid, decimal> _gridLevels = new();

    // decimal is 16 bytes, so unsynchronized reads and writes can tear when price updates and
    // queries run on different consumer threads.
    private readonly object _priceGate = new();
    private decimal _currentPrice;

    /// <summary>Adds a level; registering the same id again is ignored (first price wins).</summary>
    public void RegisterGridLevel(Guid gridLevelId, decimal price)
    {
        _gridLevels.TryAdd(gridLevelId, price);
    }

    /// <summary>Returns a snapshot of all registered level ids (unordered).</summary>
    public List<Guid> GetAllGridLevelIds() => _gridLevels.Keys.ToList();

    /// <summary>Stores the latest market price.</summary>
    public void UpdateCurrentPrice(decimal currentPrice)
    {
        lock (_priceGate)
        {
            _currentPrice = currentPrice;
        }
    }

    /// <summary>
    /// Classifies every registered level against the current price.
    /// </summary>
    /// <returns>
    /// The closest level, which is <c>null</c> only when no levels are registered. When the price is
    /// exactly midway between two levels, the lower one is chosen. Also returns the remaining levels
    /// below and above the price, each in ascending price order.
    /// </returns>
    public (Guid? ClosestLevelId, List<Guid> LowerLevelIds, List<Guid> HigherLevelIds) GetLevelsRelativeToCurrentPrice()
    {
        // Read the price once so every comparison below uses the same value.
        decimal currentPrice;
        lock (_priceGate)
        {
            currentPrice = _currentPrice;
        }

        var levels = _gridLevels.ToArray();
        if (levels.Length == 0)
        {
            return (null, [], []);
        }

        // Tie-break on price so the result does not depend on dictionary enumeration order.
        var closestLevel = levels
            .OrderBy(x => Math.Abs(x.Value - currentPrice))
            .ThenBy(x => x.Value)
            .First();

        var lowerLevels = levels
            .Where(x => x.Value < currentPrice && x.Key != closestLevel.Key)
            .OrderBy(x => x.Value)
            .Select(x => x.Key)
            .ToList();

        var higherLevels = levels
            .Where(x => x.Value > currentPrice && x.Key != closestLevel.Key)
            .OrderBy(x => x.Value)
            .Select(x => x.Key)
            .ToList();

        return (closestLevel.Key, lowerLevels, higherLevels);
    }
}
// CurrentPriceUpdatedConsumer.cs
/// <summary>
/// Fans a <see cref="CurrentPriceUpdated"/> event out into one command per grid level. Levels
/// below the price receive <see cref="GridLevelShouldBeBuy"/>, the closest level receives
/// <see cref="GridLevelShouldBeInactive"/>, and levels above receive <see cref="GridLevelShouldBeSell"/>.
/// </summary>
/// <param name="gridLevelTracker">Registry of grid levels and the latest price.</param>
/// <param name="bus">
/// Kept for constructor compatibility; messages are now published through the consume context.
/// </param>
public class CurrentPriceUpdatedConsumer(IGridLevelTracker gridLevelTracker, IBus bus) : IConsumer<CurrentPriceUpdated>
{
    public async Task Consume(ConsumeContext<CurrentPriceUpdated> context)
    {
        var currentPrice = context.Message.CurrentPrice;
        var cancellationToken = context.CancellationToken;

        // Update tracker with current price
        gridLevelTracker.UpdateCurrentPrice(currentPrice);

        // NOTE(review): these are two separate calls on a shared singleton, so a concurrent
        // CurrentPriceUpdated could change the price between them — confirm whether
        // concurrent consumption is possible.
        var (closestLevelId, lowerLevelIds, higherLevelIds) = gridLevelTracker.GetLevelsRelativeToCurrentPrice();

        // Publish through the ConsumeContext rather than IBus. This keeps the outgoing messages in
        // the same conversation as the triggering message and honours the consumer's cancellation.
        foreach (var levelId in lowerLevelIds)
        {
            await context.Publish(new GridLevelShouldBeBuy { GridLevelId = levelId }, cancellationToken);
        }

        if (closestLevelId.HasValue)
        {
            await context.Publish(new GridLevelShouldBeInactive { GridLevelId = closestLevelId.Value }, cancellationToken);
        }

        foreach (var levelId in higherLevelIds)
        {
            await context.Publish(new GridLevelShouldBeSell { GridLevelId = levelId }, cancellationToken);
        }
    }
}
// ServiceCollectionExtensions.cs (remains the same)
/// <summary>
/// Registration helpers that add MassTransit (RabbitMQ transport, in-memory saga repository)
/// to a host builder.
/// </summary>
public static class ServiceCollectionExtensions
{
    /// <summary>Registers MassTransit for the default <see cref="IBus"/>.</summary>
    /// <param name="builder">Host builder to configure.</param>
    /// <param name="massTransitConfiguration">Optional extra registrations (consumers, sagas, ...).</param>
    /// <returns>The same <paramref name="builder"/>, for chaining.</returns>
    public static IHostApplicationBuilder AddEventBus(
        this IHostApplicationBuilder builder,
        Action<IBusRegistrationConfigurator>? massTransitConfiguration = null) =>
        AddEventBus<IBus>(builder, massTransitConfiguration);

    /// <summary>
    /// Registers MassTransit for <typeparamref name="TBus"/>. It scans the entry assembly for saga
    /// state machines, sagas and activities, then applies <paramref name="massTransitConfiguration"/>,
    /// and finally configures the RabbitMQ transport.
    /// </summary>
    /// <remarks>
    /// RabbitMQ settings are read from configuration keys <c>RabbitMq:Host</c>, <c>RabbitMq:VirtualHost</c>,
    /// <c>RabbitMq:Username</c> and <c>RabbitMq:Password</c>. When a key is absent, the previous hard-coded
    /// value (localhost, "/", guest/guest) is used.
    /// </remarks>
    /// <param name="builder">Host builder to configure.</param>
    /// <param name="massTransitConfiguration">Optional extra registrations (consumers, sagas, ...).</param>
    /// <returns>The same <paramref name="builder"/>, for chaining.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="builder"/> is null.</exception>
    public static IHostApplicationBuilder AddEventBus<TBus>(
        this IHostApplicationBuilder builder,
        Action<IBusRegistrationConfigurator>? massTransitConfiguration = null)
        where TBus : class, IBus
    {
        ArgumentNullException.ThrowIfNull(builder);

        // Keep credentials out of source while preserving the local-development defaults.
        var configuration = builder.Configuration;
        var rabbitHost = configuration["RabbitMq:Host"] ?? "localhost";
        var rabbitVirtualHost = configuration["RabbitMq:VirtualHost"] ?? "/";
        var rabbitUsername = configuration["RabbitMq:Username"] ?? "guest";
        var rabbitPassword = configuration["RabbitMq:Password"] ?? "guest";

        // GetEntryAssembly() can return null (e.g. when loaded from an unmanaged host);
        // fall back to this assembly instead of handing null to the scanners.
        var scanAssembly = Assembly.GetEntryAssembly() ?? typeof(ServiceCollectionExtensions).Assembly;

        builder.Services.AddMassTransit<TBus>(x =>
        {
            x.SetKebabCaseEndpointNameFormatter();
            x.SetInMemorySagaRepositoryProvider();

            x.AddSagaStateMachines(scanAssembly);
            x.AddSagas(scanAssembly);
            x.AddActivities(scanAssembly);

            massTransitConfiguration?.Invoke(x);

            x.UsingRabbitMq((context, cfg) =>
            {
                cfg.Host(rabbitHost, rabbitVirtualHost, h =>
                {
                    h.Username(rabbitUsername);
                    h.Password(rabbitPassword);
                });

                cfg.ConfigureEndpoints(context);
            });
        });

        return builder;
    }
}