Update Metal, CUDA Candle impls and ISQ #816

Open
wants to merge 5 commits into
base: master

EricLBuehler
Owner

Metal: MLX
CUDA: PaddedData for quantized

@ChristianWeyer could you please test if this builds on Metal?

github-actions bot commented Oct 2, 2024

Code Metrics Report
===============================================================================
 Language            Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                2           35           28            0            7
 Dockerfile              1           34           25            0            9
 Happy                   1          442          369            0           73
 JSON                   12          105          104            0            1
 Python                 52         2268         1930           69          269
 TOML                   20          625          559            2           64
 YAML                    2           21           19            2            0
-------------------------------------------------------------------------------
 Jupyter Notebooks       4            0            0            0            0
 |- Markdown             2           77           32           31           14
 |- Python               2          196          169            1           26
 (Total)                            273          201           32           40
-------------------------------------------------------------------------------
 Markdown               38         2760            0         2094          666
 |- BASH                 6          103          100            0            3
 |- JSON                 1           12           12            0            0
 |- Python               5           92           82            0           10
 |- Rust                 9          322          274            0           48
 |- TOML                 2           75           63            0           12
 (Total)                           3364          531         2094          739
-------------------------------------------------------------------------------
 Rust                  260        75643        68177         1547         5919
 |- Markdown           123         1217           25         1117           75
 (Total)                          76860        68202         2664         5994
===============================================================================
 Total                 393        81933        71211         3714         7008
===============================================================================
  

@ChristianWeyer

@EricLBuehler

error[E0412]: cannot find type `NSUInteger` in this scope
   --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/quantized/metal.rs:211:65
    |
211 |         let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;
    |                                                                 ^^^^^^^^^^ not found in this scope
    |
help: consider importing this type alias
    |
1   + use metal::NSUInteger;
    |

error[E0277]: `Option<f64>` doesn't implement `std::fmt::Display`
    --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/metal_backend/mod.rs:1581:83
     |
1581 |                     MetalError::Message(format!("mlx matmul doesn't support alpha {s}")).into(),
     |                                                                                   ^^^ `Option<f64>` cannot be formatted with the default formatter
     |
     = help: the trait `std::fmt::Display` is not implemented for `Option<f64>`
     = note: in format strings you may be able to use `{:?}` (or {:#?} for pretty-print) instead
     = note: this error originates in the macro `$crate::__export::format_args` which comes from the expansion of the macro `format` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: `Option<f64>` doesn't implement `std::fmt::Display`
    --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/metal_backend/mod.rs:1602:83
     |
1602 |                     MetalError::Message(format!("mlx matmul doesn't support alpha {s}")).into(),
     |                                                                                   ^^^ `Option<f64>` cannot be formatted with the default formatter
     |
     = help: the trait `std::fmt::Display` is not implemented for `Option<f64>`
     = note: in format strings you may be able to use `{:?}` (or {:#?} for pretty-print) instead
     = note: this error originates in the macro `$crate::__export::format_args` which comes from the expansion of the macro `format` (in Nightly builds, run with -Z macro-backtrace for more info)
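
Note on the two E0277 errors above: `Option<f64>` implements `Debug` but not `Display`, so the `{s}` placeholder cannot format it. A minimal sketch of two ways around that follows. The helper names are hypothetical and are not candle's API; only the message text is taken from the log.

```rust
// Hypothetical helper: the real message is built inside candle's Metal backend.
fn alpha_error_message(alpha: Option<f64>) -> String {
    // `{alpha}` would need `Display`, which `Option<f64>` lacks (E0277).
    // `{alpha:?}` uses `Debug`, which `Option<f64>` does implement.
    format!("mlx matmul doesn't support alpha {alpha:?}")
}

// Alternative: unwrap the option and format the inner f64 with `Display`.
fn alpha_error_message_plain(alpha: Option<f64>) -> String {
    match alpha {
        Some(a) => format!("mlx matmul doesn't support alpha {a}"),
        None => "mlx matmul doesn't support alpha".to_string(),
    }
}

fn main() {
    println!("{}", alpha_error_message(Some(0.5)));
    println!("{}", alpha_error_message_plain(Some(0.5)));
}
```

The `Debug` form is the smaller change; the `match` form gives a cleaner message at the cost of a few more lines.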

error[E0609]: no field `count` on type `&metal::QMetalStorage`
   --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/quantized/metal.rs:211:26
    |
211 |         let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;
    |                          ^^^^^ unknown field
    |
    = note: available fields are: `dtype`, `device`, `buffer`

error[E0599]: no method named `size_in_bytes` found for enum `quantized::GgmlDType` in the current scope
   --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/quantized/metal.rs:211:45
    |
211 |         let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;
    |                                             ^^^^^^^^^^^^^ method not found in `GgmlDType`
    |
   ::: /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/quantized/mod.rs:144:1
    |
144 | pub enum GgmlDType {
    | ------------------ method `size_in_bytes` not found for this enum

error[E0609]: no field `count` on type `&metal::QMetalStorage`
   --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/quantized/metal.rs:223:44
    |
223 |         Ok(read_to_vec::<u8>(&buffer, self.count))
    |                                            ^^^^^ unknown field
    |
    = note: available fields are: `dtype`, `device`, `buffer`

error[E0004]: non-exhaustive patterns: `DType::I16` and `DType::I32` not covered
    --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/metal_backend/mod.rs:2127:26
     |
2127 |         let name = match dtype {
     |                          ^^^^^ patterns `DType::I16` and `DType::I32` not covered
     |
note: `DType` defined here
    --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/dtype.rs:8:10
     |
8    | pub enum DType {
     |          ^^^^^
...
14   |     I16,
     |     --- not covered
15   |     // Signed 32 bits integer.
16   |     I32,
     |     --- not covered
     = note: the matched value is of type `DType`
help: ensure that all possible cases are being handled by adding a match arm with a wildcard pattern, a match arm with multiple or-patterns as shown, or multiple match arms
     |
2137 ~             },
2138 +             DType::I16 | DType::I32 => todo!()
     |

   Compiling ureq v2.10.1
error[E0004]: non-exhaustive patterns: `DType::I16` not covered
   --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/sort.rs:146:23
    |
146 |                 match storage.dtype() {
    |                       ^^^^^^^^^^^^^^^ pattern `DType::I16` not covered
    |
note: `DType` defined here
   --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/dtype.rs:8:10
    |
8   | pub enum DType {
    |          ^^^^^
...
14  |     I16,
    |     --- not covered
    = note: the matched value is of type `DType`
help: ensure that all possible cases are being handled by adding a match arm with a wildcard pattern or an explicit pattern as shown
    |
154 ~                     DType::I32 => "asort_asc_i32",
155 ~                     DType::I16 => todo!(),
    |

error[E0004]: non-exhaustive patterns: `DType::I16` not covered
   --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/sort.rs:157:23
    |
157 |                 match storage.dtype() {
    |                       ^^^^^^^^^^^^^^^ pattern `DType::I16` not covered
    |
note: `DType` defined here
   --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/d08212c/candle-core/src/dtype.rs:8:10
    |
8   | pub enum DType {
    |          ^^^^^
...
14  |     I16,
    |     --- not covered
    = note: the matched value is of type `DType`
help: ensure that all possible cases are being handled by adding a match arm with a wildcard pattern or an explicit pattern as shown
    |
165 ~                     DType::I32 => "asort_desc_i32",
166 ~                     DType::I16 => todo!(),
    |

   Compiling hf-hub v0.3.2
Some errors have detailed explanations: E0004, E0277, E0412, E0599, E0609.
For more information about an error, try `rustc --explain E0004`.
error: could not compile `candle-core` (lib) due to 9 previous errors
warning: build failed, waiting for other jobs to finish...
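
Note on the E0004 errors above: adding `I16` and `I32` to `DType` means every `match` on it has to handle them. The compiler suggests `todo!()`, which panics at runtime. One alternative is an explicit error arm, sketched below on a cut-down stand-in enum. The enum, the function name, and every kernel name other than `asort_asc_i32` are assumptions for illustration, not candle's actual definitions.

```rust
// Stand-in for candle's `DType`, reduced to a few variants for illustration.
#[derive(Debug, Clone, Copy)]
enum DType {
    U8,
    I16,
    I32,
    F32,
}

// Hypothetical lookup of a sort kernel name by dtype.
fn asort_asc_kernel(dtype: DType) -> Result<&'static str, String> {
    match dtype {
        DType::U8 => Ok("asort_asc_u8"),
        DType::I32 => Ok("asort_asc_i32"),
        DType::F32 => Ok("asort_asc_f32"),
        // Explicit arm instead of `todo!()`: the match is exhaustive and
        // unsupported dtypes surface as an error rather than a panic.
        DType::I16 => Err(format!("asort is not implemented for {dtype:?}")),
    }
}

fn main() {
    println!("{:?}", asort_asc_kernel(DType::I32));
    println!("{:?}", asort_asc_kernel(DType::I16));
}
```

Listing the variant explicitly, rather than using a `_` wildcard, keeps the compiler reporting this `match` the next time a variant is added.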

@EricLBuehler
Owner Author

EricLBuehler commented Oct 2, 2024

Hi @ChristianWeyer, thanks for testing it! I pushed some changes that should hopefully fix this. Could you please test it again?

I also made some changes that will affect the UQFF backend. Could you please quickly test that as well:

cargo run --features metal --release -- -i --isq Q4K plain -m microsoft/Phi-3.5-mini-instruct --write-uqff test.uqff

And then:

cargo run --features metal --release -- -i --isq Q4K plain -m microsoft/Phi-3.5-mini-instruct --from-uqff test.uqff

Sorry for the inconvenience! My Metal hardware should be arriving soon :)

@ChristianWeyer

> Hi @ChristianWeyer, thanks for testing it! I pushed some changes that should hopefully fix this. Could you please test it again?
>
> I also made some changes that will affect the UQFF backend. Could you please quickly test that as well:
>
> cargo run --features metal --release -- -i --isq Q4K plain -m microsoft/Phi-3.5-mini-instruct --write-uqff test.uqff

Here we go:

error[E0308]: mismatched types
   --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/c04861d/candle-core/src/quantized/metal.rs:223:39
    |
223 |         Ok(read_to_vec::<u8>(&buffer, self.buffer.length()))
    |            -----------------          ^^^^^^^^^^^^^^^^^^^^ expected `usize`, found `u64`
    |            |
    |            arguments to this function are incorrect
    |
note: function defined here
   --> /Users/christianweyer/.cargo/git/checkouts/candle-c6a149c3b35a488f/c04861d/candle-core/src/quantized/metal.rs:240:4
    |
240 | fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
    |    ^^^^^^^^^^^                            --------
help: you can convert a `u64` to a `usize` and panic if the converted value doesn't fit
    |
223 |         Ok(read_to_vec::<u8>(&buffer, self.buffer.length().try_into().unwrap()))
    |                                                           ++++++++++++++++++++

   Compiling rustls-webpki v0.102.8
For more information about this error, try `rustc --explain E0308`.
error: could not compile `candle-core` (lib) due to 1 previous error
warning: build failed, waiting for other jobs to finish...
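
Note on the E0308 above: the log shows the buffer length as a `u64`, while `read_to_vec` takes a `usize`. A plain `as usize` cast compiles but would silently truncate on a target where `usize` is narrower than 64 bits. A checked conversion is sketched below. The function name is hypothetical and this is not necessarily the cast used in the PR.

```rust
use std::num::TryFromIntError;

// Hypothetical helper: convert a u64 buffer length to usize, failing
// explicitly if the value does not fit instead of truncating.
fn buffer_len_as_usize(len: u64) -> Result<usize, TryFromIntError> {
    usize::try_from(len)
}

fn main() {
    match buffer_len_as_usize(4096) {
        Ok(n) => println!("length fits in usize: {n}"),
        Err(e) => println!("length does not fit: {e}"),
    }
}
```

On 64-bit macOS the conversion cannot fail, so the difference from `as usize` only matters for portability.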

@EricLBuehler
Owner Author

@ChristianWeyer thanks, I added a quick cast - could you please try it again?

@ChristianWeyer

warning: variable does not need to be mutable
   --> mistralrs-core/src/pipeline/isq.rs:144:18
    |
144 |             let (mut tensors, mapper) = match organization {
    |                  ----^^^^^^^
    |                  |
    |                  help: remove this `mut`
    |
    = note: `#[warn(unused_mut)]` on by default

error[E0382]: borrow of moved value: `tensors`
    --> mistralrs-core/src/pipeline/isq.rs:401:25
     |
144  |               let (mut tensors, mapper) = match organization {
     |                    ----------- move occurs because `tensors` has type `Vec<(&mut Arc<dyn QuantMethod>, std::option::Option<usize>)>`, which does not implement the `Copy` trait
...
342  |                       tensors.into_iter().zip(devices_and_dtypes).for_each(
     |                               ----------- `tensors` moved due to this method call
...
353  |                           .into_iter()
     |                            ----------- `tensors` moved due to this method call
...
401  | /                         tensors
402  | |                             .iter()
     | |___________________________________^ value borrowed here after move
     |
note: `std::iter::IntoIterator::into_iter` takes ownership of the receiver `self`, which moves `tensors`
    --> /Users/christianweyer/.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/core/src/iter/traits/collect.rs:346:18
     |
346  |     fn into_iter(self) -> Self::IntoIter;
     |                  ^^^^
     = note: borrow occurs due to deref coercion to `[(&mut Arc<dyn QuantMethod>, std::option::Option<usize>)]`
note: deref defined here
    --> /Users/christianweyer/.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/alloc/src/vec/mod.rs:2818:5
     |
2818 |     type Target = [T];
     |     ^^^^^^^^^^^
help: you could `clone` the value and consume it, if the `&mut Arc<dyn QuantMethod>: Clone` trait bound could be satisfied
     |
352  |                     tensors.clone()
     |                            ++++++++
help: you could `clone` the value and consume it, if the `&mut Arc<dyn QuantMethod>: Clone` trait bound could be satisfied
     |
342  |                     tensors.clone().into_iter().zip(devices_and_dtypes).for_each(
     |                            ++++++++

For more information about this error, try `rustc --explain E0382`.
warning: `mistralrs-core` (lib) generated 1 warning
error: could not compile `mistralrs-core` (lib) due to 1 previous error; 1 warning emitted
warning: build failed, waiting for other jobs to finish...
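
Note on the E0382 above: `into_iter()` takes the `Vec` by value, so `tensors` is gone by the time `.iter()` is called on it later. The `clone` suggestion cannot apply here because `&mut` references are not `Clone`. One way out is to borrow with `iter_mut()` instead, sketched below with `String` standing in for the tensor type. All names are hypothetical and this is not the code in `isq.rs`.

```rust
// Hypothetical stand-in: `String` plays the role of `Arc<dyn QuantMethod>`.
fn uppercase_then_count(items: &mut [String]) -> usize {
    // Same shape as in the error: mutable references paired with metadata.
    let mut tensors: Vec<(&mut String, Option<usize>)> = items
        .iter_mut()
        .enumerate()
        .map(|(i, s)| (s, Some(i)))
        .collect();

    // `tensors.into_iter()` here would move the Vec and make the later
    // `.iter()` call fail with E0382. `iter_mut()` only borrows it.
    tensors.iter_mut().for_each(|(s, _)| {
        // `s` is `&mut &mut String`, so two derefs reach the String.
        **s = s.to_uppercase();
    });

    // The Vec is still owned here and can be read again.
    tensors.iter().filter(|(_, layer)| layer.is_some()).count()
}

fn main() {
    let mut layers = vec!["q_proj".to_string(), "k_proj".to_string()];
    let n = uppercase_then_count(&mut layers);
    println!("{n} layers: {layers:?}");
}
```

With `iter_mut()` the binding has to stay `mut`, and each element becomes a reference to a reference, which is why the assignment needs the double dereference.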

@EricLBuehler
Owner Author

@ChristianWeyer I think this should compile now. Could you please test it? :)

Also, could you please test the UQFF path:

cargo run --features metal --release -- -i --isq Q4K plain -m microsoft/Phi-3.5-mini-instruct --write-uqff test.uqff

And then:

cargo run --features metal --release -- -i --isq Q4K plain -m microsoft/Phi-3.5-mini-instruct --from-uqff test.uqff

Thanks!

@ChristianWeyer

❯ cargo run --features metal --release -- -i --isq Q4K plain -m microsoft/Phi-3.5-mini-instruct --write-uqff test.uqff
   Compiling mistralrs-core v0.3.1 (/Users/christianweyer/Sources/mistral.rs/mistralrs-core)
error[E0308]: mismatched types
   --> mistralrs-core/src/pipeline/isq.rs:344:39
    |
344 |   ...                   *tensor = tensor
    |  _______________________-------___^
    | |                       |
    | |                       expected due to the type of this binding
345 | | ...                       .clone()
346 | | ...                       .apply_isq(dtype, device.clone(), &n_quantized)
347 | | ...                       .unwrap();
    | |___________________________________^ expected `&mut Arc<dyn QuantMethod>`, found `Arc<dyn QuantMethod>`
    |
    = note: expected mutable reference `&mut Arc<(dyn QuantMethod + 'static)>`
                          found struct `Arc<dyn QuantMethod>`
help: consider dereferencing here to assign to the mutably borrowed value
    |
344 |                             **tensor = tensor
    |                             +

error[E0308]: mismatched types
   --> mistralrs-core/src/pipeline/isq.rs:357:39
    |
357 |   ...                   *tensor = tensor
    |  _______________________-------___^
    | |                       |
    | |                       expected due to the type of this binding
358 | | ...                       .clone()
359 | | ...                       .apply_isq(dtype, device.clone(), &n_quantized)
360 | | ...                       .unwrap();
    | |___________________________________^ expected `&mut Arc<dyn QuantMethod>`, found `Arc<dyn QuantMethod>`
    |
    = note: expected mutable reference `&mut Arc<(dyn QuantMethod + 'static)>`
                          found struct `Arc<dyn QuantMethod>`
help: consider dereferencing here to assign to the mutably borrowed value
    |
357 |                             **tensor = tensor
    |                             +

For more information about this error, try `rustc --explain E0308`.
error: could not compile `mistralrs-core` (lib) due to 2 previous errors
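
Note on the two E0308 errors above: `tensor` here is a `&mut &mut Arc<dyn QuantMethod>`, so `*tensor` is still a reference and only `**tensor` names the `Arc` itself. A minimal sketch of the compiler's suggestion follows. The trait is a simplified stand-in with a hypothetical `apply_isq` signature (no dtype, device, or counter arguments, and no `Result`), not the real mistral.rs trait.

```rust
use std::sync::Arc;

// Simplified stand-in for the real trait; the signature is an assumption.
trait QuantMethod {
    fn name(&self) -> String;
    fn apply_isq(self: Arc<Self>) -> Arc<dyn QuantMethod>;
}

struct Unquantized;
struct Quantized;

impl QuantMethod for Unquantized {
    fn name(&self) -> String {
        "unquantized".to_string()
    }
    fn apply_isq(self: Arc<Self>) -> Arc<dyn QuantMethod> {
        Arc::new(Quantized)
    }
}

impl QuantMethod for Quantized {
    fn name(&self) -> String {
        "quantized".to_string()
    }
    fn apply_isq(self: Arc<Self>) -> Arc<dyn QuantMethod> {
        self
    }
}

fn quantize_all(tensors: &mut [(&mut Arc<dyn QuantMethod>, Option<usize>)]) {
    for (tensor, _layer) in tensors.iter_mut() {
        // `tensor` is `&mut &mut Arc<dyn QuantMethod>`.
        let current: Arc<dyn QuantMethod> = Arc::clone(&**tensor);
        // `*tensor = ...` would expect a `&mut Arc<..>` (E0308);
        // `**tensor` assigns through both references to the Arc.
        **tensor = current.apply_isq();
    }
}

fn main() {
    let mut a: Arc<dyn QuantMethod> = Arc::new(Unquantized);
    let mut b: Arc<dyn QuantMethod> = Arc::new(Unquantized);
    let mut tensors: Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)> =
        vec![(&mut a, Some(0)), (&mut b, None)];
    quantize_all(&mut tensors);
    println!("{} {}", a.name(), b.name());
}
```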

2 participants