I have two matrices, P and V, and when I take their dot product with Triton I get results that are inconsistent with PyTorch.
The P and V matrices are as follows. P is basically a softmax output, which is why it is all zeros except for the final column, so the result of the dot product should be the final row of V.
P = torch.zeros((32, 32), device='cuda', dtype=torch.float32)
P[:, -1] = 1
V = torch.arange(32*64, 64*64, device='cuda', dtype=torch.float32).reshape(32, 64)
On calling tl.dot(P, V), both are loading correctly (or so it appears to me), but the output is
[4032., 4032., 4034., 4034., 4036., 4036., 4038., 4038., 4040., 4040.,
4042., 4042., 4044., 4044., 4046., 4046., 4048., 4048., 4050., 4050.,
4052., 4052., 4054., 4054., 4056., 4056., 4058., 4058., 4060., 4060.,
4062., 4062., 4064., 4064., 4066., 4066., 4068., 4068., 4070., 4070.,
4072., 4072., 4074., 4074., 4076., 4076., 4078., 4078., 4080., 4080.,
4082., 4082., 4084., 4084., 4086., 4086., 4088., 4088., 4090., 4090.,
4092., 4092., 4094., 4094.]
instead of what I get from torch.matmul, which is
[4032., 4033., 4034., 4035., 4036., 4037., 4038., 4039., 4040., 4041.,
4042., 4043., 4044., 4045., 4046., 4047., 4048., 4049., 4050., 4051.,
4052., 4053., 4054., 4055., 4056., 4057., 4058., 4059., 4060., 4061.,
4062., 4063., 4064., 4065., 4066., 4067., 4068., 4069., 4070., 4071.,
4072., 4073., 4074., 4075., 4076., 4077., 4078., 4079., 4080., 4081.,
4082., 4083., 4084., 4085., 4086., 4087., 4088., 4089., 4090., 4091.,
4092., 4093., 4094., 4095.]
The following is the code I'm testing this out in:
import triton
import triton.language as tl
import torch

torch.cuda.is_available()
torch.set_printoptions(profile="full")

@triton.jit
def test_kernel(x_ptr, y_ptr, output_ptr,
                M, K, N,
                stride_xm, stride_xk,
                stride_yk, stride_yn,
                stride_om, stride_on,
                BLOCK_SIZE_M: tl.constexpr,
                BLOCK_SIZE_K: tl.constexpr,
                BLOCK_SIZE_N: tl.constexpr):
    # Offsets of this program's tile in the output.
    pid_m = tl.program_id(axis=0) * BLOCK_SIZE_M
    pid_n = tl.program_id(axis=1) * BLOCK_SIZE_N
    # Pointers into x and y for this tile (only a single 32x32 / 32x64 tile is launched in this test).
    x_ptr += (pid_m + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_xm + (pid_n + tl.arange(0, BLOCK_SIZE_K))[None, :] * stride_xk
    y_ptr += (pid_m + tl.arange(0, BLOCK_SIZE_K))[:, None] * stride_yk + (pid_n + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_yn
    x = tl.load(x_ptr)
    y = tl.load(y_ptr)
    output_offset = (pid_m + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_om + (pid_n + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_on
    tl.store(output_ptr + output_offset, tl.dot(x, y))

def helper(x: torch.Tensor, y: torch.Tensor):
    M, K = x.shape
    K1, N = y.shape
    assert K == K1
    output = torch.empty((M, N), device='cuda', dtype=torch.float32)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(N, meta['BLOCK_SIZE_N']))
    test_kernel[grid](x, y, output,
                      M, K, N,
                      x.stride(0), x.stride(1),
                      y.stride(0), y.stride(1),
                      output.stride(0), output.stride(1),
                      BLOCK_SIZE_N=64,
                      BLOCK_SIZE_K=32,
                      BLOCK_SIZE_M=32,
                      )
    return output
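For completeness, this is roughly how I'm comparing the two outputs (P and V are defined as above; the variable names here are just for illustration):

out_triton = helper(P, V)          # Triton kernel output
out_torch = torch.matmul(P, V)     # PyTorch reference
print(out_triton[0])               # the row with repeated values shown above
print(out_torch[0])                # 4032., 4033., ..., 4095.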
The strangest thing is that when I define V = torch.arange(0, 32*64, device='cuda', dtype=torch.float32).reshape(32, 64), it works as expected. Is there something about pointer operations that I'm missing here?
1 Answer
The problem has been solved (thanks to dai on Discord). The issue is that input_precision defaults to tf32 for tl.dot, and TF32 has only a 10-bit mantissa, so the trailing digits of the inputs are lost. The problem was even more pronounced with V = torch.arange(4096, 4096 + 2048, device='cuda', dtype=torch.float32), where the output was [6080., 6080., 6080., 6080., 6084., 6084., 6084., 6084., 6088., ...]. Switching to IEEE input precision with tl.dot(x, y, input_precision="ieee") solved the issue.
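To see where the repeated values come from, you can emulate TF32's reduced precision on the last row of V. The sketch below simply truncates the 13 low mantissa bits of each float32 (TF32 keeps FP32's 8-bit exponent but only 10 explicit mantissa bits); the hardware's actual rounding may differ slightly, but it reproduces the pattern in the question:

import torch

def to_tf32(x: torch.Tensor) -> torch.Tensor:
    # Reinterpret the float32 bits and zero the 13 low mantissa bits,
    # keeping the 10 explicit mantissa bits that TF32 retains.
    bits = x.view(torch.int32)
    return (bits & ~0x1FFF).view(torch.float32)

v_last = torch.arange(4032.0, 4096.0)  # last row of V from the question
print(to_tf32(v_last)[:8])
# tensor([4032., 4032., 4034., 4034., 4036., 4036., 4038., 4038.])

With 11 bits of significand precision (1 implicit + 10 explicit), every integer below 2048 is exactly representable, integers in [2048, 4096) round to multiples of 2, and integers in [4096, 8192) round to multiples of 4. That matches why V = torch.arange(0, 32*64, ...) worked, why the original example repeats values in steps of 2, and why the arange(4096, 4096 + 2048) example repeats them in steps of 4.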