Instruction | General theme | Writemask | Optional special features |
---|---|---|---|
matfp |
z[j][i] ±= f(x[i], y[j]) |
9 bit X, 9 bit Y | Indexed X or Y, shuffle X, shuffle Y, positive selection |
Bit | Width | Meaning | Notes |
---|---|---|---|
10 | 22 | A64 reserved instruction | Must be 0x201000 >> 10 |
5 | 5 | Instruction | Must be 21 |
0 | 5 | 5-bit GPR index | See below for the meaning of the 64 bits in the GPR |
Bit | Width | Meaning | Notes |
---|---|---|---|
63 | 1 | Ignored | |
57 | 6 | Y enable value | Meaning dependent upon associated mode, see bit 23 |
54 | 3 | Must be zero | No-op otherwise |
53 | 1 | Indexed load (1 ) or regular load (0 ) |
|
(53=1) 52 | 1 | Ignored | |
(53=1) 49 | 3 | Register to index into | |
(53=1) 48 | 1 | Indices are 4 bits (1 ) or 2 bits (0 ) |
|
(53=1) 47 | 1 | Indexed load of Y (1 ) or of X (0 ) |
|
(53=0) 47 | 6 | ALU mode | |
46 | 1 | Ignored | |
42 | 4 | Lane width mode | |
41 | 1 | Ignored | |
38 | 3 | X enable mode | |
37 | 1 | Ignored | |
32 | 5 | X enable value | Meaning dependent upon associated mode |
31 | 1 | Ignored | |
29 | 2 | X shuffle | |
27 | 2 | Y shuffle | |
26 | 1 | Ignored | |
23 | 3 | Y enable mode | Not high bits of Z row |
20 | 3 | Z row | High bits ignored in some lane width modes |
19 | 1 | Ignored | |
10 | 9 | X offset (in bytes) | |
9 | 1 | Ignored | |
0 | 9 | Y offset (in bytes) |
ALU modes:
Floating-point operation | 47 | Notes |
---|---|---|
z + x*y |
0 |
|
z - x*y |
1 |
|
x <= 0 ? 0 : y |
4 |
Z input not used |
no-op | anything else |
Lane width modes:
X,Y | Z | 42 | Notes |
---|---|---|---|
bf16 | bf16 (one row from each two) | 0 |
M2 only |
bf16 | f32 (all rows, interleaved pairs) | 1 |
M2 only |
f16 | f32 (all rows, interleaved pairs) | 3 |
|
f32 | f32 (one row from each four) | 4 |
|
f64 | f64 (one row from each eight) | 7 |
|
f16 | f16 (one row from each two) | anything else |
X/Y enable modes:
Mode | Meaning of value (N) |
---|---|
0 |
Enable all lanes (0 ), or odd lanes only (1 ), or even lanes only (2 ), or enable all lanes but override the ALU operation to 0.0 (3 ) or enable all lanes but override their value to 0.0 (4 or 5 ) or no lanes enabled (anything else) |
1 |
Only enable lane #N |
2 |
Only enable the first N lanes, or all lanes when N is zero |
3 |
Only enable the last N lanes, or all lanes when N is zero |
4 |
Only enable the first N lanes (no lanes when N is zero) |
5 |
Only enable the last N lanes (no lanes when N is zero) |
6 |
No lanes enabled |
7 |
No lanes enabled |
Performs a fused-multiply-add (or other ALU operation) outer-product between an X vector, a Y vector, and a 2D grid of Z values, accumulating onto Z. All three of X and Y and Z have the same element type, either f16 or f32 or f64 (or bf16 on M2). Alternatively, when X and Y are both f16 (or bf16 on M2), Z can have type f32, in which case the entire 64x64 byte grid of Z is used, with even lanes of X going into even Z registers and odd lanes of X going into odd Z registers (see Mixed lane widths).
See matfp.c, and vecfp.c for the shared ALU. Note the code in test.c to set the DN bit of fpcr
.
A representative sample is:
void emulate_AMX_MATFP(amx_state* state, uint64_t operand) {
if ((operand >> 54) & 7) {
return;
}
operand &=~ (1ull << 37);
operand &=~ (1ull << 63);
int alumode = (operand & MATFP_INDEXED_LOAD) ? 0 : (operand >> 47) & 0x3f;
if (alumode == 2 || alumode == 3 || alumode >= 5) {
return;
}
uint32_t xybits, zbits, bf16 = 0;
switch ((operand >> 42) & 0xf) {
case 0: xybits = 16; if (AMX_VER >= AMX_VER_M2) { zbits = 16; bf16 = 1; } else { zbits = 16; } break;
case 1: xybits = 16; if (AMX_VER >= AMX_VER_M2) { zbits = 32; bf16 = 1; } else { zbits = 16; } break;
case 3: xybits = 16; zbits = 32; break;
case 4: xybits = 32; zbits = 32; break;
case 7: xybits = 64; zbits = 64; break;
default: xybits = 16; zbits = 16; break;
}
uint32_t xybytes = xybits / 8;
amx_reg x;
amx_reg y;
load_xy_reg(&x, state->x, (operand >> 10) & 0x1FF);
load_xy_reg(&y, state->y, operand & 0x1FF);
if (operand & MATFP_INDEXED_LOAD) {
uint32_t src_reg = (operand >> 49) & 7;
uint32_t ibits = (operand & MATFP_INDEXED_LOAD_4BIT) ? 4 : 2;
if (operand & MATFP_INDEXED_LOAD_Y) {
load_xy_reg_indexed(y.u8, state->y[src_reg].u8, ibits, xybits);
} else {
load_xy_reg_indexed(x.u8, state->x[src_reg].u8, ibits, xybits);
}
}
xy_shuffle(x.u8, (operand >> 29) & 3, xybytes);
xy_shuffle(y.u8, (operand >> 27) & 3, xybytes);
uint64_t x_enable = parse_writemask(operand >> 32, xybytes, 9);
uint64_t y_enable = parse_writemask((((operand >> 23) & 7) << 6) | (operand >> 58), xybytes, 9);
int32_t omask = -1;
if (((operand >> (32+6)) & 7) == 0) {
uint32_t val = (operand >> 32) & 0x3F;
if (val == 3) {
omask = 0;
} else if (val == 4 || val == 5) {
memset(&x, 0, 64);
}
}
if (((operand >> 23) & 7) == 0) {
uint32_t val = (operand >> 58) & 0x3F;
if (val == 3) {
omask = 0;
} else if (val == 4 || val == 5) {
memset(&y, 0, 64);
}
}
uint64_t z_row = (operand >> 20) & 7;
if (zbits == 16) {
if (bf16) {
...
} else {
for (uint32_t j = 0; j < 32; j += 1) {
if (!((y_enable >> (j*xybytes)) & 1)) continue;
for (uint32_t i = 0; i < 32; i += 1) {
if (!((x_enable >> (i*xybytes)) & 1)) continue;
_Float16* z = &state->z[bit_select(j*2, z_row, 1)].f16[i];
*z = omask ? vecfp_alu16(x.f16[i], y.f16[j], *z, alumode) : 0;
}
}
}
} else {
...
}
}
_Float16 vecfp_alu16(_Float16 x, _Float16 y, _Float16 z, int alumode) {
switch (alumode) {
case 0: __asm("fmadd %h0, %h1, %h2, %h3" : "=w"(z) : "w"(x), "w"(y), "w"(z)); break;
case 1: __asm("fmsub %h0, %h1, %h2, %h3" : "=w"(z) : "w"(x), "w"(y), "w"(z)); break;
case 4: z = (x <= (_Float16)0) ? (_Float16)0 : y; break;
}
return z;
}
Note that a fused-multiply-add counts as two floating-point operations. A measurement of 1.0 GFLOPS would mean 109 floating-point operations per second. The measurements are done without any load or store instructions; real-world workloads will need loads and stores, and thus will achieve lower numbers.
X and Y being f16[32]
, each Z accumulator being f16[32][32]
, ALU operation being z + x*y
or z - x*y
:
Z Accumulators | 1 Thread | 2 Threads | 3 Threads | 4 Threads | 5 Threads | 6 Threads |
---|---|---|---|---|---|---|
1 per thread | 1444.6 GFLOPS | 2945.9 GFLOPS | 2668.9 GFLOPS | 4296.1 GFLOPS | 4692.2 GFLOPS | 5082.6 GFLOPS |
2 per thread | 2944.7 GFLOPS | 5856.6 GFLOPS | 4857.7 GFLOPS | 6150.3 GFLOPS | 5565.6 GFLOPS | 6186.8 GFLOPS |
X and Y being f16[32]
, each Z accumulator being f32[32][32]
, ALU operation being z + x*y
or z - x*y
:
Z Accumulators | 1 Thread | 2 Threads | 3 Threads | 4 Threads | 5 Threads | 6 Threads |
---|---|---|---|---|---|---|
1 per thread | 1466.1 GFLOPS | 2926.1 GFLOPS | 2665.2 GFLOPS | 2856.6 GFLOPS | 2835.5 GFLOPS | 2874.2 GFLOPS |
X and Y being f32[16]
, each Z accumulator being f32[16][16]
, ALU operation being z + x*y
or z - x*y
:
Z Accumulators | 1 Thread | 2 Threads | 3 Threads | 4 Threads | 5 Threads | 6 Threads |
---|---|---|---|---|---|---|
1 per thread | 367.0 GFLOPS | 725.8 GFLOPS | 923.2 GFLOPS | 1053.6 GFLOPS | 1572.3 GFLOPS | 1418.4 GFLOPS |
2 per thread | 735.4 GFLOPS | 1474.6 GFLOPS | 1321.7 GFLOPS | 1790.2 GFLOPS | 2708.2 GFLOPS | 2673.2 GFLOPS |
3 per thread | 1095.0 GFLOPS | 2215.9 GFLOPS | 2010.7 GFLOPS | 2469.2 GFLOPS | 2865.7 GFLOPS | 2861.8 GFLOPS |
4 per thread | 1478.4 GFLOPS | 2955.3 GFLOPS | 2771.5 GFLOPS | 2786.5 GFLOPS | 2903.5 GFLOPS | 2963.8 GFLOPS |
X and Y being f64[8]
, each Z accumulator being f64[8][8]
, ALU operation being z + x*y
or z - x*y
:
Z Accumulators | 1 Thread | 2 Threads | 3 Threads | 4 Threads | 5 Threads | 6 Threads |
---|---|---|---|---|---|---|
1 per thread | 91.6 GFLOPS | 184.9 GFLOPS | 210.1 GFLOPS | 329.7 GFLOPS | 411.7 GFLOPS | 408.6 GFLOPS |
2 per thread | 184.7 GFLOPS | 369.7 GFLOPS | 334.1 GFLOPS | 491.7 GFLOPS | 720.2 GFLOPS | 712.6 GFLOPS |
3 per thread | 276.2 GFLOPS | 553.2 GFLOPS | 546.5 GFLOPS | 652.4 GFLOPS | 757.9 GFLOPS | 685.1 GFLOPS |
4 per thread | 364.6 GFLOPS | 736.5 GFLOPS | 702.4 GFLOPS | 770.0 GFLOPS | 767.5 GFLOPS | 754.2 GFLOPS |
5 per thread | 368.3 GFLOPS | 731.0 GFLOPS | 596.8 GFLOPS | 763.6 GFLOPS | 793.2 GFLOPS | 710.5 GFLOPS |
6 per thread | 369.5 GFLOPS | 737.3 GFLOPS | 776.0 GFLOPS | 768.2 GFLOPS | 794.7 GFLOPS | 787.6 GFLOPS |
7 per thread | 369.1 GFLOPS | 738.6 GFLOPS | 606.1 GFLOPS | 708.7 GFLOPS | 787.0 GFLOPS | 792.2 GFLOPS |
8 per thread | 369.6 GFLOPS | 736.8 GFLOPS | 686.6 GFLOPS | 773.6 GFLOPS | 779.9 GFLOPS | 790.0 GFLOPS |