Skip to content

Commit

Permalink
Add gemm correctness test and improve the performance test
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Oct 26, 2023
1 parent 5cd97c3 commit 6ece864
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 47 deletions.
225 changes: 181 additions & 44 deletions examples/mps/matrix-multiplication/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,124 @@
use metal::mps::*;
use metal::*;
use rand::{thread_rng, Rng};
use std::io::Write;
use std::ops::{AddAssign, Mul};
use std::{array, io};

fn main() {
correctness();
performance();
}

fn correctness() {
// First verify the correctness of the naive solution
let a = Matrix::new([1, 2, 6, 24, 120, 720], 3, 2);
let b = Matrix::new([1, 2, 3, 5, 8, 13], 2, 3);
let result = matrix_mul::<Int32>(a, b);
assert_eq!(
result.entries(),
&[11, 18, 29, 126, 204, 330, 3720, 6000, 9720]
);

const M: u64 = 100;
const N: u64 = 100;
const K: u64 = 100;
const ITERATIONS: usize = 50;

let device = Device::system_default().expect("No device found");
let command_queue = device.new_command_queue();

println!("Correctness: ");
for i in 0..ITERATIONS {
progress_bar(i, ITERATIONS);

let left = generate_matrix::<Float32, M, K>();
let right = generate_matrix::<Float32, K, N>();

let command_buffer = command_queue.new_command_buffer();
let result = encode_gemm(
&device,
command_buffer,
false,
false,
&left,
&right,
1.0,
0.0,
);
command_buffer.commit();
command_buffer.wait_until_completed();

let expected = matrix_mul(left, right);
approx_eq(result.contents(), expected.entries().to_vec());
}

println!(" ✅\n");
}

fn performance() {
const M: u64 = 4096;
const N: u64 = 4096;
const K: u64 = 4096;

const ITERATIONS: usize = 50;

println!("Performance: ");
println!("Generating input matrices: (f32 {M}x{K} and f16 {K}x{N})");
// Generate random matrices
let left = generate_matrix::<Float32, M, K>();
let right = generate_matrix::<Float16, K, N>();

// Setup
let device = Device::system_default().expect("No device found");
let command_queue = device.new_command_queue();

let cases = [
(false, false, 1.0, 0.0),
(true, false, 1.0, 0.0),
(false, true, 1.0, 0.0),
(false, false, 0.5, 0.0),
(false, false, 1.0, 0.5),
];
for (t_left, t_right, alpha, beta) in cases {
println!("Running with transpose left: {t_left}, transpose right: {t_right}, alpha: {alpha}, beta: {beta}");
let mut flops: Vec<f64> = vec![];

let mut total_time = std::time::Duration::new(0, 0);
for i in 0..ITERATIONS {
progress_bar(i, ITERATIONS);

let start = std::time::Instant::now();
let command_buffer = command_queue.new_command_buffer();
let _ = encode_gemm(
&device,
command_buffer,
t_left,
t_right,
&left,
&right,
alpha,
beta,
);
command_buffer.commit();
command_buffer.wait_until_completed();

let time = std::time::Instant::now() - start;

// Calculate GFLOPS
// C <- alpha * AB + beta * C
// Operations = 2(M * N * K)
flops.push((M * N * (2 * K + 2)) as f64 / (time.as_secs_f64() * 1e+9f64));
}
println!(" ✅");

let avg_gflops = flops.iter().sum::<f64>() / flops.len() as f64;
println!("Avg GFLOPS: {}", avg_gflops);
println!("Total time: {:#?}", total_time);
println!("Avg time: {:#?}", total_time / ITERATIONS as u32);
println!()
}
}

fn generate_matrix<T, const ROWS: u64, const COLS: u64>() -> Matrix<T>
where
Expand All @@ -9,59 +127,78 @@ where
{
let mut rng = thread_rng();
Matrix::new(
(0..ROWS * COLS).map(|_| T::from_f64(rng.gen())).collect(),
(0..ROWS * COLS).map(|_| T::from_f64(rng.gen())),
ROWS as NSUInteger,
COLS as NSUInteger,
)
}

fn main() {
const M: u64 = 4096;
const N: u64 = 4096;
const K: u64 = 4096;
const RUNS: u64 = 100;
// Naive matrix multiplication for testing
fn matrix_mul<T: MPSDataType>(a: Matrix<T>, b: Matrix<T>) -> Matrix<T>
where
T::Type: AddAssign + Mul<Output = T::Type> + Copy,
{
assert_eq!(a.columns(), b.rows());
let sum_count = a.columns() as usize;
let rows = a.rows() as usize;
let columns = b.columns() as usize;
let size = rows * columns;

let transpose_left = false;
let transpose_right = false;
let alpha = 1.0;
let beta = 0.0;
let mut entries = Vec::with_capacity(size);

// Generate random matrices
let left = generate_matrix::<Float32, M, K>();
let right = generate_matrix::<Float32, K, N>();
for idx in 0..size {
let i = idx / rows;
let j = idx % columns;

// Setup
let device = Device::system_default().expect("No device found");
let command_queue = device.new_command_queue();
let mut total_time = std::time::Duration::new(0, 0);
let mut sum = T::from_f64(0.0);
for di in 0..sum_count {
sum += a.entry(i, di) * b.entry(di, j);
}
entries.push(sum);
}

for _ in 0..RUNS {
let command_buffer = command_queue.new_command_buffer();
let start = std::time::Instant::now();
let _ = encode_gemm(
&device,
command_buffer,
transpose_left,
transpose_right,
&left,
&right,
alpha,
beta,
);
command_buffer.commit();
command_buffer.wait_until_completed();
let time = std::time::Instant::now() - start;
total_time += time;
Matrix::new(entries, a.rows(), b.columns())
}

fn euclidean_distance<T>(a: Vec<T>, b: Vec<T>) -> f64
where
T: Into<f64> + Clone + Copy,
{
assert_eq!(a.len(), b.len(), "Lengths not equal");

let mut sum = 0.0;

for i in 0..a.len() {
sum += (a[i].into() - b[i].into()).powi(2);
}

// Calculate GFLOPS
// C <- alpha * AB + beta * C
// Operations = M * N * (K+2) + M * N * K
let ops_count = M * N * (2 * K + 2);
let ops_count = (ops_count * RUNS) as f64;
let gflops = ops_count / (total_time.as_secs_f64() * 1000e+3f64);
// TODO: Something is wrong here hehe
println!("GFLOPS: {}", gflops);
println!("Total time: {:?}", total_time);
println!("Avg time: {:?}", total_time / RUNS as u32);
sum.sqrt()
}

fn approx_eq<T>(a: Vec<T>, b: Vec<T>)
where
T: Into<f64> + Clone + Copy,
{
assert_eq!(a.len(), b.len(), "Lengths not equal");

let avg_magnitude = 0.004f64;
let avg_deviation = (a.len() as f64).sqrt();
let tolerance = avg_magnitude.max(avg_deviation * 3e-7);

let distance = euclidean_distance(a, b);
assert!(
distance < tolerance,
"Distance not less than tolerance: {} < {} ",
distance,
tolerance
);
}

fn progress_bar(i: usize, len: usize) {
print!("\r");
print!("[");
print!("{}", "=".repeat(i));
print!("{}", " ".repeat(len - i - 1));
print!("]");
io::stdout().flush().unwrap();
}
25 changes: 22 additions & 3 deletions src/mps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -908,16 +908,35 @@ pub struct Matrix<T: MPSDataType> {
}

impl<T: MPSDataType> Matrix<T> {
pub fn new(entries: Vec<T::Type>, rows: NSUInteger, columns: NSUInteger) -> Self {
pub fn new<E: IntoIterator<Item = T::Type>>(
entries: E,
rows: NSUInteger,
columns: NSUInteger,
) -> Matrix<T> {
let entries: Vec<T::Type> = entries.into_iter().collect();
assert_eq!(entries.len(), rows as usize * columns as usize);
Self {
entries,
rows,
columns,
}
}
pub fn entries(&self) -> Vec<T::Type> {
self.entries.clone()
pub fn entries(&self) -> &[T::Type] {
&self.entries
}

pub fn entry(&self, row: usize, column: usize) -> T::Type {
assert!(row < self.rows as usize);
assert!(column < self.columns as usize);
self.entries[row * self.columns as usize + column]
}

pub fn rows(&self) -> NSUInteger {
self.rows
}

pub fn columns(&self) -> NSUInteger {
self.columns
}
}

Expand Down

0 comments on commit 6ece864

Please sign in to comment.