Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make things pub #292

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
09c5333
Initial matrix multiplication support. Still some bugs to iron out bu…
ivarflakstad Oct 18, 2023
4f4df06
Remove over the top abstractions. Changed test to example
ivarflakstad Oct 20, 2023
530e2c1
Use "apply" scheme for gemm function. Buffer reading bug needs fix.
ivarflakstad Oct 21, 2023
5cd97c3
Matrix multiplication example is now a performance test. Seems to wor…
ivarflakstad Oct 25, 2023
aed9b29
Add gemm correctness test and improve the performance test
ivarflakstad Oct 26, 2023
2c1490c
small improvements
ivarflakstad Oct 30, 2023
313a768
Merge branch 'gfx-rs:master' into mps-matrix-multiplication-kernel
ivarflakstad Oct 30, 2023
c130b1c
Merge pull request #1 from ivarflakstad/mps-matrix-multiplication-kernel
ivarflakstad Oct 31, 2023
7791312
wait_until_completed() outside of inner loop
ivarflakstad Oct 31, 2023
933b3b8
Easier to change matrix types in gemm benchmark
ivarflakstad Oct 31, 2023
c667fc4
Merge pull request #2 from ivarflakstad/mps-improve-gemm-benchmark
ivarflakstad Oct 31, 2023
36ae080
Ditch generic Matrix and use MatrixBuffer instead. Create buffers sep…
ivarflakstad Oct 31, 2023
7910f10
Mark encode_gemm c parameter as &mut
ivarflakstad Oct 31, 2023
aaf2647
Return Result from encode_gemm
ivarflakstad Oct 31, 2023
ad60768
MPSDataType TYPE_ID -> u32
ivarflakstad Oct 31, 2023
3864ca7
Merge pull request #3 from ivarflakstad/mps-gemm-matrixbuffer-improve…
ivarflakstad Oct 31, 2023
a2aff24
Merge branch 'master' into mpsdatatype-should-be-u32
ivarflakstad Oct 31, 2023
3a4bd86
Merge pull request #4 from ivarflakstad/mpsdatatype-should-be-u32
ivarflakstad Oct 31, 2023
5a66ce1
feat: Make the features testable and threadsafe.
Narsil Oct 31, 2023
b2ca05d
Merge pull request #5 from ivarflakstad/fix_threading_autorelease
ivarflakstad Oct 31, 2023
bc755b4
Some cleanup.
Narsil Oct 31, 2023
adeb7c4
Remove the nodrop thing (Just manually increment the reference count).
Narsil Nov 1, 2023
c1df369
Merge pull request #6 from ivarflakstad/some_cleanup
ivarflakstad Nov 1, 2023
cec862c
Making things public.
Narsil Nov 1, 2023
09406d8
Making some changes to be able to own the command buffer (not sure if…
Narsil Nov 1, 2023
3a0b6fb
Fixing things to get owned command buffer.
Narsil Nov 1, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ block = "0.1.6"
foreign-types = "0.5"
dispatch = { version = "0.2", optional = true }
paste = "1"
half = "2.3.1"

[dependencies.objc]
version = "0.2.4"
Expand Down Expand Up @@ -76,7 +77,13 @@ name = "compute"
path = "examples/compute/main.rs"

[[example]]
name = "mps"
name = "mps-matrix-multiplication"
path = "examples/mps/matrix-multiplication/main.rs"
required-features = ["mps"]

[[example]]
name = "mps-ray-intersection"
path = "examples/mps/ray-intersection/main.rs"
required-features = ["mps"]

[[example]]
Expand Down
2 changes: 1 addition & 1 deletion examples/compute/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn main() {
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();

resolve_samples_into_buffer(command_buffer, &counter_sample_buffer, &destination_buffer);
resolve_samples_into_buffer(&command_buffer, &counter_sample_buffer, &destination_buffer);

command_buffer.commit();
command_buffer.wait_until_completed();
Expand Down
245 changes: 245 additions & 0 deletions examples/mps/matrix-multiplication/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
use std::any::type_name;
use std::io;
use std::io::Write;
use std::ops::{AddAssign, Mul};

use rand::{thread_rng, Rng};

use metal::mps::matrix::*;
use metal::mps::*;
use metal::*;

fn main() {
correctness();
performance();
}

fn correctness() {
// First verify the correctness of the naive solution
let m = 3;
let n = 3;
let k = 2;
let a = vec![1, 2, 6, 24, 120, 720];
let b = vec![1, 2, 6, 24, 120, 720];
let result = matrix_mul::<Int32>(a, b, m, n, k);
assert_eq!(
result,
&[49, 242, 1446, 582, 2892, 17316, 17400, 86640, 519120]
);

const M: u64 = 100;
const N: u64 = 100;
const K: u64 = 100;
const ITERATIONS: usize = 50;

let device = Device::system_default().expect("No device found");
let command_queue = device.new_command_queue();

println!("Correctness: ");
for i in 0..ITERATIONS {
progress_bar(i, ITERATIONS);

let a = generate_matrix::<Float32, M, K>(&device);
let b = generate_matrix::<Float32, K, N>(&device);
let mut c = generate_matrix::<Float32, K, N>(&device);

let command_buffer = command_queue.new_command_buffer();
encode_gemm(
&device,
&command_buffer,
false,
false,
&a,
&b,
&mut c,
1.0,
0.0,
)
.expect("Encoding failed");
command_buffer.commit();
command_buffer.wait_until_completed();

let expected = matrix_mul::<Float32>(
a.contents(),
b.contents(),
M as usize,
K as usize,
N as usize,
);
approx_eq(c.contents(), expected);
}

println!(" ✅\n");
}

fn short_type_name<T>() -> String {
let name = type_name::<T>();
let parts = name.split("::");
parts.last().unwrap().to_string()
}

fn performance() {
const M: u64 = 4096;
const N: u64 = 4096;
const K: u64 = 4096;

type A = Float32;
type B = Float16;
type C = Float32;
const ITERATIONS: usize = 50;

println!("Performance: ");

let a_tname = short_type_name::<A>();
let b_tname = short_type_name::<B>();
let c_tname = short_type_name::<C>();
println!("{M}x{K}x{a_tname} * {K}x{N}x{b_tname} = {M}x{N}x{c_tname}");

let device = Device::system_default().expect("No device found");

println!("Generating input matrices...");
// Generate random matrices
let a = generate_matrix::<A, M, K>(&device);
let b = generate_matrix::<B, K, N>(&device);
let mut c = generate_matrix::<C, K, N>(&device);

let cases = [
(false, false, 1.0, 0.0),
(true, false, 1.0, 0.0),
(false, true, 1.0, 0.0),
(false, false, 0.5, 0.0),
(false, false, 1.0, 0.5),
];
for (t_left, t_right, alpha, beta) in cases {
println!("Running with transpose left: {t_left}, transpose right: {t_right}, alpha: {alpha}, beta: {beta}");
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();

let start = std::time::Instant::now();
for i in 0..ITERATIONS {
progress_bar(i, ITERATIONS);

encode_gemm(
&device,
&command_buffer,
t_left,
t_right,
&a,
&b,
&mut c,
alpha,
beta,
)
.expect("Encoding failed");
}
command_buffer.commit();
command_buffer.wait_until_completed();

let total_time = start.elapsed();

// Calculate GFLOPS
// C <- alpha * AB + beta * C
// Operations = 2(M * N * K)
let avg_gflops = (ITERATIONS as u64 * (M * N * (2 * K - 1))) as f64
/ (total_time.as_secs_f64() * 1e+9f64);

println!(" ✅");

println!("Avg GFLOPS: {}", avg_gflops);
println!("Total time: {:#?}", total_time);
println!()
}
}

fn generate_matrix<T, const ROWS: u64, const COLS: u64>(device: &Device) -> MatrixBuffer<T>
where
T: MPSDataType,
GEMMInput<T>: Valid,
{
let mut rng = thread_rng();

// Create descriptors for the matrices.
let row_bytes_for_columns = MatrixDescriptor::row_bytes_for_columns(COLS, T::TYPE_ID);

// Create buffers
let options = MTLResourceOptions::StorageModeShared;
let data = (0..ROWS * COLS)
.map(|_| T::from_f64(rng.gen()))
.collect::<Vec<T::Type>>();
let buffer =
device.new_buffer_with_data(data.as_ptr().cast(), ROWS * row_bytes_for_columns, options);

MatrixBuffer::from_buffer(buffer, ROWS, COLS)
}

// Naive matrix multiplication for testing
fn matrix_mul<T: MPSDataType>(
a: Vec<T::Type>,
b: Vec<T::Type>,
m: usize,
n: usize,
k: usize,
) -> Vec<T::Type>
where
T::Type: AddAssign + Mul<Output = T::Type> + Copy,
{
let size = m * n;

let mut c = Vec::with_capacity(size);

for idx in 0..size {
let i = idx / m;
let j = idx % n;

let mut sum = T::from_f64(0.0);
for di in 0..k {
sum += a[(i * k) + di] * b[(di * n) + j];
}
c.push(sum);
}

c
}

fn euclidean_distance<T>(a: Vec<T>, b: Vec<T>) -> f64
where
T: Into<f64> + Clone + Copy,
{
assert_eq!(a.len(), b.len(), "Lengths not equal");

let mut sum = 0.0;

for i in 0..a.len() {
sum += (a[i].into() - b[i].into()).powi(2);
}

sum.sqrt()
}

fn approx_eq<T>(a: Vec<T>, b: Vec<T>)
where
T: Into<f64> + Clone + Copy,
{
assert_eq!(a.len(), b.len(), "Lengths not equal");

let avg_magnitude = 0.004f64;
let avg_deviation = (a.len() as f64).sqrt();
let tolerance = avg_magnitude.max(avg_deviation * 3e-7);

let distance = euclidean_distance(a, b);
assert!(
distance < tolerance,
"Distance not less than tolerance: {} < {} ",
distance,
tolerance
);
}

fn progress_bar(i: usize, len: usize) {
print!("\r");
print!("[");
print!("{}", "=".repeat(i));
print!("{}", " ".repeat(len - i - 1));
print!("]");
io::stdout().flush().unwrap();
}
26 changes: 13 additions & 13 deletions examples/mps/main.rs → examples/mps/ray-intersection/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use metal::mps::raytracing::*;
use metal::*;
use std::ffi::c_void;
use std::mem;
Expand All @@ -7,15 +8,15 @@ struct Vertex {
xyz: [f32; 3],
}

type Ray = mps::MPSRayOriginMinDistanceDirectionMaxDistance;
type Intersection = mps::MPSIntersectionDistancePrimitiveIndexCoordinates;
type Ray = MPSRayOriginMinDistanceDirectionMaxDistance;
type Intersection = MPSIntersectionDistancePrimitiveIndexCoordinates;

// Original example taken from https://sergeyreznik.github.io/metal-ray-tracer/part-1/index.html
fn main() {
let device = Device::system_default().expect("No device found");

let library_path =
std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("examples/mps/shaders.metallib");
let library_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("examples/mps/ray-intersection/shaders.metallib");
let library = device
.new_library_with_file(library_path)
.expect("Failed to load shader library");
Expand Down Expand Up @@ -61,26 +62,25 @@ fn main() {
);

// Build an acceleration structure using our vertex and index buffers containing the single triangle.
let acceleration_structure = mps::TriangleAccelerationStructure::from_device(&device)
let acceleration_structure = TriangleAccelerationStructure::from_device(&device)
.expect("Failed to create acceleration structure");

acceleration_structure.set_vertex_buffer(Some(&vertex_buffer));
acceleration_structure.set_vertex_stride(vertex_stride as u64);
acceleration_structure.set_index_buffer(Some(&index_buffer));
acceleration_structure.set_index_type(mps::MPSDataType::UInt32);
acceleration_structure.set_index_type(mps::UInt32);
acceleration_structure.set_triangle_count(1);
acceleration_structure.set_usage(mps::MPSAccelerationStructureUsage::None);
acceleration_structure.set_usage(MPSAccelerationStructureUsage::None);
acceleration_structure.rebuild();

let ray_intersector =
mps::RayIntersector::from_device(&device).expect("Failed to create ray intersector");
RayIntersector::from_device(&device).expect("Failed to create ray intersector");

ray_intersector.set_ray_stride(mem::size_of::<Ray>() as u64);
ray_intersector.set_ray_data_type(mps::MPSRayDataType::OriginMinDistanceDirectionMaxDistance);
ray_intersector.set_ray_data_type(MPSRayDataType::OriginMinDistanceDirectionMaxDistance);
ray_intersector.set_intersection_stride(mem::size_of::<Intersection>() as u64);
ray_intersector.set_intersection_data_type(
mps::MPSIntersectionDataType::DistancePrimitiveIndexCoordinates,
);
ray_intersector
.set_intersection_data_type(MPSIntersectionDataType::DistancePrimitiveIndexCoordinates);

// Create a buffer to hold generated rays and intersection results
let ray_count = 1024;
Expand Down Expand Up @@ -115,7 +115,7 @@ fn main() {
// Intersect rays with triangles inside acceleration structure
ray_intersector.encode_intersection_to_command_buffer(
&command_buffer,
mps::MPSIntersectionType::Nearest,
MPSIntersectionType::Nearest,
&ray_buffer,
0,
&intersection_buffer,
Expand Down
File renamed without changes.
10 changes: 10 additions & 0 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,14 @@ impl BufferRef {
pub fn gpu_address(&self) -> u64 {
unsafe { msg_send![self, gpuAddress] }
}

pub fn read_to_slice<T>(&self, len: usize) -> &[T] {
let contents_ptr = self.contents() as *const T;
assert!(!contents_ptr.is_null());
unsafe { std::slice::from_raw_parts(contents_ptr, len) }
}

pub fn read_to_vec<T: Clone>(&self, len: usize) -> Vec<T> {
self.read_to_slice(len).to_vec()
}
}
8 changes: 8 additions & 0 deletions src/commandqueue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ impl CommandQueueRef {
unsafe { msg_send![self, commandBuffer] }
}

pub fn new_owned_command_buffer(&self) -> CommandBuffer {
unsafe {
let buffer: CommandBuffer = msg_send![self, commandBuffer];
let () = msg_send![buffer, retain];
buffer
}
}

pub fn new_command_buffer_with_unretained_references(&self) -> &CommandBufferRef {
unsafe { msg_send![self, commandBufferWithUnretainedReferences] }
}
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ pub extern crate foreign_types;
#[macro_use]
pub extern crate paste;

pub extern crate half;

use std::{
borrow::{Borrow, ToOwned},
marker::PhantomData,
Expand Down
Loading