diff --git a/mistralrs-bench/src/main.rs b/mistralrs-bench/src/main.rs index 6d54c860d..53895d4c4 100644 --- a/mistralrs-bench/src/main.rs +++ b/mistralrs-bench/src/main.rs @@ -2,11 +2,11 @@ use candle_core::Device; use clap::Parser; use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table}; use mistralrs_core::{ - initialize_logging, paged_attn_supported, parse_isq_value, Constraint, DefaultSchedulerMethod, - DeviceLayerMapMetadata, DeviceMapMetadata, DrySamplingParams, IsqType, Loader, LoaderBuilder, - MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelDType, ModelSelected, NormalRequest, - PagedAttentionConfig, Request, RequestMessage, Response, SamplingParams, SchedulerConfig, - TokenSource, Usage, + get_model_dtype, initialize_logging, paged_attn_supported, parse_isq_value, Constraint, + DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata, DrySamplingParams, IsqType, + Loader, LoaderBuilder, MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelSelected, + NormalRequest, PagedAttentionConfig, Request, RequestMessage, Response, SamplingParams, + SchedulerConfig, TokenSource, Usage, }; use std::sync::Arc; use std::{fmt::Display, num::NonZeroUsize}; @@ -348,6 +348,8 @@ fn main() -> anyhow::Result<()> { None => None, }; + let dtype = get_model_dtype(&args.model)?; + let loader: Box = LoaderBuilder::new(args.model) .with_use_flash_attn(use_flash_attn) .with_prompt_batchsize(prompt_batchsize) @@ -477,7 +479,7 @@ fn main() -> anyhow::Result<()> { let pipeline = loader.load_model_from_hf( None, token_source, - &ModelDType::Auto, + &dtype, &device, false, mapper,