diff --git a/Cargo.toml b/Cargo.toml index 25bc266..2de3277 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ llvm-sys = "120" log = { version = "0.4", features = ["release_max_level_info"] } pest = "2.1.3" pest_derive = "2.1.0" +ndarray = "0.13" pretty_env_logger = "0.4" rev_slice = "0.1.5" serde = { version = "1.0.123", features = ["derive"] } diff --git a/config/mempool.yaml b/config/mempool.yaml index 5ef488f..58627d5 100644 --- a/config/mempool.yaml +++ b/config/mempool.yaml @@ -11,28 +11,39 @@ address: nr_cores: 0x40000010 uart: 0xC0000000 # Not supported in MemPool - barrier_reg: 0x50000000 + barrier_reg: + start: 0x50000000 + offset: 0x100000 cluster_base_hartid: 0x50000001 cluster_num: 0x50000002 cluster_id: 0x50000003 cl_clint: 0x40000060 clint: 0xFFFF0000 memory: - - tcdm: - start: 0x0 - end: 0x100000 - latency: 5 - dram: - start: 0x80000000 - end: 0x80010000 - latency: 10 - # Not used in MemPool - ext_tcdm: [] - periphs: - callbacks: [] - end: 0x100000 - latency: 5 - start: 0x100000 + tcdm: + start: 0x0 + size: 0x100000 + offset: 0x100000 + latency: 5 + dram: + start: 0x80000000 + size: 0x01000000 + offset: 0x0 + latency: 10 + periphs: + start: 0x40000000 + size: 0x20000 + offset: 0x0 + latency: 5 + callbacks: + - name: zero-memory + size: 0x40 + - name: mempool-ita + size: 32 + - name: zero-memory + size: 0xFFA0 + - name: mempool-dma + size: 28 inst_latency: mul: 3 mulh: 3 diff --git a/src/engine.rs b/src/engine.rs index fa88bbd..08b7f19 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -713,7 +713,7 @@ impl<'a, 'b> Cpu<'a, 'b> { } } - fn binary_load(&self, addr: u32, size: u8) -> u32 { + pub fn binary_load(&self, addr: u32, size: u8) -> u32 { match addr { x if x == self.engine.config.address.tcdm_start => { self.engine.config.memory.tcdm.start @@ -801,7 +801,7 @@ impl<'a, 'b> Cpu<'a, 'b> { let periph_addr = addr - (self.engine.config.memory.periphs.start + self.engine.config.memory.periphs.offset * id as u32); - self.engine.peripherals.load(id, periph_addr, size) + self.engine.peripherals.load(&self, id, periph_addr, size) } // Bootrom x if x >= self.engine.config.bootrom.start @@ -854,7 +854,7 @@ impl<'a, 'b> Cpu<'a, 'b> { } } - fn binary_store(&self, addr: u32, value: u32, mask: u32, size: u8) { + pub fn binary_store(&self, addr: u32, value: u32, mask: u32, size: u8) { match addr { x if x == self.engine.config.address.tcdm_start => (), // tcdm_start x if x == self.engine.config.address.tcdm_end => (), // tcdm_end @@ -950,7 +950,7 @@ impl<'a, 'b> Cpu<'a, 'b> { + self.engine.config.memory.periphs.offset * id as u32); self.engine .peripherals - .store(id, periph_addr, value, mask, size) + .store(&self, id, periph_addr, value, mask, size) } // Bootrom x if x >= self.engine.config.bootrom.start @@ -1055,6 +1055,29 @@ impl<'a, 'b> Cpu<'a, 'b> { prev as u32 } + pub fn binary_memcpy(&self, mut dest: u32, mut src: u32, n: u32) { + // n in bytes + trace!("MEMCPY From {:08x} to {:08x} num: {:08x}", src, dest, n); + if dest % 4 == 0 && src % 4 == 0 && n % 4 == 0 { + warn!("MEMCPY aligned"); + // Aligned transfer + for _ in 0..n / 4 { + let tmp = self.binary_load(src, 2); + self.binary_store(dest, tmp, u32::MAX, 2); + src += 4; + dest += 4; + } + } else { + warn!("MEMCPY unaligned"); + for _ in 0..n { + let tmp = self.binary_load(src, 0); + self.binary_store(dest, tmp, (u8::MAX as u32) << (8 * (dest % 4)), 0); + src += 1; + dest += 1; + } + } + } + fn binary_csr_read(&self, csr: riscv::Csr, notrace: u32) -> u32 { if notrace == 0 { trace!("Read CSR {:?}", csr); diff --git a/src/peripherals.rs b/src/peripherals.rs index 9463c4b..bf9900e 100644 --- a/src/peripherals.rs +++ b/src/peripherals.rs @@ -4,7 +4,9 @@ /// Generic, memory-mapped peripherals implemented using runtime callbacks. use crate::configuration::Callback; -use std::sync::atomic::{AtomicU32, Ordering}; +use crate::Cpu; +use ndarray::{s, Array1, Array2, Array3}; +use std::sync::atomic::{AtomicI32, AtomicU32, Ordering}; use PeriphReq::{Load, Store}; /// Reference held by execution engine, referencing each peripheral instance in each cluster @@ -40,15 +42,22 @@ impl Peripherals { ); } - pub fn load(&self, cluster_id: usize, addr: u32, size: u8) -> u32 { - self.load_store(cluster_id, addr, size, Load) + pub fn load(&self, cpu: &Cpu, cluster_id: usize, addr: u32, size: u8) -> u32 { + self.load_store(cpu, cluster_id, addr, size, Load) } - pub fn store(&self, cluster_id: usize, addr: u32, value: u32, mask: u32, size: u8) { - self.load_store(cluster_id, addr, size, Store(value, mask)); + pub fn store(&self, cpu: &Cpu, cluster_id: usize, addr: u32, value: u32, mask: u32, size: u8) { + self.load_store(cpu, cluster_id, addr, size, Store(value, mask)); } - fn load_store(&self, cluster_id: usize, mut addr: u32, size: u8, req: PeriphReq) -> u32 { + fn load_store( + &self, + cpu: &Cpu, + cluster_id: usize, + mut addr: u32, + size: u8, + req: PeriphReq, + ) -> u32 { for i in &self.cluster_peripherals[cluster_id] { if addr < i.0 { return match req { @@ -60,7 +69,7 @@ impl Peripherals { addr, size ); - self.peripherals[i.1].load(addr, size) + self.peripherals[i.1].load(cpu, addr, size) } Store(val, mask) => { trace!( @@ -72,7 +81,7 @@ impl Peripherals { mask, val ); - self.peripherals[i.1].store(addr, val, mask, size); + self.peripherals[i.1].store(cpu, addr, val, mask, size); 0 } }; @@ -111,12 +120,12 @@ pub trait Peripheral { /// should return the same name as in the config file fn get_name(&self) -> &'static str; /// store instruction - fn store(&self, addr: u32, value: u32, mask: u32, size: u8); + fn store(&self, cpu: &Cpu, addr: u32, value: u32, mask: u32, size: u8); /// load instruction - fn load(&self, addr: u32, size: u8) -> u32; + fn load(&self, cpu: &Cpu, addr: u32, size: u8) -> u32; } -/// Function called by the engine to get the peripheral types. This function should +/// Function called by the cpu to get the peripheral types. This function should /// return a vector containing an instance of each available peripherable type. /// To add a new peripheral type, declare it below and add it here. pub fn get_peripheral_types() -> Vec> { @@ -124,6 +133,8 @@ pub fn get_peripheral_types() -> Vec> { Box::new(Semaphores::default()), Box::new(Fence::default()), Box::new(ZeroMemory::default()), + Box::new(MemPoolDMA::default()), + Box::new(MemPoolITA::default()), ] } @@ -138,14 +149,14 @@ impl Peripheral for Fence { "fence" } - fn store(&self, addr: u32, val: u32, _mask: u32, _: u8) { + fn store(&self, _cpu: &Cpu, addr: u32, val: u32, _mask: u32, _: u8) { match addr { 0x0 => self.set.store(val, Ordering::SeqCst), _ => self.current.store(val, Ordering::SeqCst), } } - fn load(&self, _: u32, _: u8) -> u32 { + fn load(&self, _cpu: &Cpu, _: u32, _: u8) -> u32 { self.current.fetch_add(1, Ordering::SeqCst); while self.set.load(Ordering::SeqCst) != self.current.load(Ordering::SeqCst) {} 0 @@ -164,7 +175,7 @@ impl Peripheral for Semaphores { "semaphores" } - fn store(&self, addr: u32, val: u32, _mask: u32, _: u8) { + fn store(&self, _cpu: &Cpu, addr: u32, val: u32, _mask: u32, _: u8) { match addr { 0x0 => self.empty_count.store(val, Ordering::SeqCst), 0x4 => { @@ -220,7 +231,7 @@ impl Peripheral for Semaphores { } } - fn load(&self, _: u32, _: u8) -> u32 { + fn load(&self, _cpu: &Cpu, _: u32, _: u8) -> u32 { 0 } } @@ -233,9 +244,668 @@ impl Peripheral for ZeroMemory { "zero-memory" } - fn store(&self, _: u32, _: u32, _: u32, _: u8) {} + fn store(&self, _cpu: &Cpu, _: u32, _: u32, _: u32, _: u8) {} - fn load(&self, _: u32, _: u8) -> u32 { + fn load(&self, _cpu: &Cpu, _: u32, _: u8) -> u32 { 0 } } + +#[derive(Default)] +struct MemPoolDMA { + src_addr: AtomicU32, + dst_addr: AtomicU32, + num_bytes: AtomicU32, + conf: AtomicU32, + status: AtomicU32, + next_id: AtomicU32, + done: AtomicU32, +} + +impl Peripheral for MemPoolDMA { + /// should return the same name as in the config file + fn get_name(&self) -> &'static str { + "mempool-dma" + } + /// store instruction + fn store(&self, _cpu: &Cpu, addr: u32, value: u32, _mask: u32, _size: u8) { + match addr { + 0x00 => self.src_addr.store(value, Ordering::SeqCst), + 0x04 => self.dst_addr.store(value, Ordering::SeqCst), + 0x08 => self.num_bytes.store(value, Ordering::SeqCst), + 0x0C => self.conf.store(value, Ordering::SeqCst), + 0x10 => (), /* status: Write has no effect */ + 0x14 => (), /* next_id: Write has no effect */ + 0x18 => (), /* done: Write has no effect */ + _ => unimplemented!(), + } + self.done.store(0, Ordering::SeqCst); + } + /// load instruction + fn load(&self, cpu: &Cpu, addr: u32, _size: u8) -> u32 { + match addr { + 0x00 => self.src_addr.load(Ordering::SeqCst), + 0x04 => self.dst_addr.load(Ordering::SeqCst), + 0x08 => self.num_bytes.load(Ordering::SeqCst), + 0x0C => self.conf.load(Ordering::SeqCst), + 0x10 => self.status.load(Ordering::SeqCst), + 0x14 => { + cpu.binary_memcpy( + self.dst_addr.load(Ordering::SeqCst), + self.src_addr.load(Ordering::SeqCst), + self.num_bytes.load(Ordering::SeqCst), + ); + self.done.store(1, Ordering::SeqCst); + self.next_id.load(Ordering::SeqCst) + } + 0x18 => self.done.load(Ordering::SeqCst), + _ => unimplemented!(), + } + } +} + +#[derive(Default)] +struct MemPoolITA { + config: AtomicU32, + start_address: AtomicU32, + eps_mul_0: AtomicU32, + eps_mul_1: AtomicU32, + right_shift_0: AtomicU32, + right_shift_1: AtomicU32, + add_0: AtomicI32, + add_1: AtomicI32, +} + +impl Peripheral for MemPoolITA { + /// should return the same name as in the config file + fn get_name(&self) -> &'static str { + "mempool-ita" + } + /// store instruction + fn store(&self, cpu: &Cpu, addr: u32, value: u32, _mask: u32, _size: u8) { + match addr { + 0x00 => unsafe { + self.config.store(value as u32, Ordering::SeqCst); + // Out addresses are currently hardcoded in ITA + let out_addresses: [u32; 4] = [0x000C3000, 0x000D3000, 0x000E3000, 0x000F3000]; + let head_config = std::mem::transmute::(value); + let mut return_value = 0; + debug!("[ITA] Store config {:x}", value); + for (i, c) in head_config.iter().enumerate() { + if *c & 0x1 == 1 { + // Start ITA + self.run_ita( + cpu, + self.start_address.load(Ordering::SeqCst), + out_addresses[i], + self.eps_mul_0.load(Ordering::SeqCst), + self.eps_mul_1.load(Ordering::SeqCst), + self.right_shift_0.load(Ordering::SeqCst), + self.right_shift_1.load(Ordering::SeqCst), + self.add_0.load(Ordering::SeqCst), + self.add_1.load(Ordering::SeqCst), + ); + // Set `config` to done + return_value |= 0x1a << (8 * i); + } + } + self.config.store(return_value, Ordering::SeqCst); + debug!("[ITA] Save config {:x}", return_value); + }, + 0x04 => self.start_address.store(value as u32, Ordering::SeqCst), + 0x08 => self.eps_mul_0.store(value, Ordering::SeqCst), + 0x0C => self.eps_mul_1.store(value, Ordering::SeqCst), + 0x10 => self.right_shift_0.store(value, Ordering::SeqCst), + 0x14 => self.right_shift_1.store(value, Ordering::SeqCst), + 0x18 => unsafe {self.add_0.store(std::mem::transmute::(value), Ordering::SeqCst)}, + 0x1C => unsafe {self.add_1.store(std::mem::transmute::(value), Ordering::SeqCst)}, + _ => unimplemented!(), + } + } + /// load instruction + fn load(&self, _cpu: &Cpu, addr: u32, _size: u8) -> u32 { + match addr { + 0x00 => { + let conf = self.config.load(Ordering::SeqCst); + if conf == 0x1a1a1a1a { + self.config.store(0x04040404, Ordering::SeqCst); + } + conf + } + 0x04 => self.start_address.load(Ordering::SeqCst), + 0x08 => self.eps_mul_0.load(Ordering::SeqCst), + 0x0C => self.eps_mul_1.load(Ordering::SeqCst), + 0x10 => self.right_shift_0.load(Ordering::SeqCst), + 0x14 => self.right_shift_1.load(Ordering::SeqCst), + 0x18 => unsafe { std::mem::transmute::(self.add_0.load(Ordering::SeqCst)) }, + 0x1C => unsafe { std::mem::transmute::(self.add_1.load(Ordering::SeqCst)) }, + _ => unimplemented!(), + } + } +} + +impl MemPoolITA { + fn transpose_3d(data: &mut Array3, m: u32, n: u32, p: u32) { + let copy = data.clone(); + for j in 0..m { + for i in 0..n { + for h in 0..p { + data[[j as usize, i as usize, h as usize]] = + copy[[j as usize, h as usize, i as usize]]; + } + } + } + } + + unsafe fn ita_load_2d( + cpu: &Cpu, + data: &mut Array2, + mut address: u32, + m: u32, + n: u32, + splits: u32, + ) { + for split in 0..splits { + for j in 0..m { + for i in (0..n / splits).step_by(4) { + let word = cpu.binary_load(address, 2); + let elements = std::mem::transmute::(word); + for (offset, e) in elements.iter().enumerate() { + data[[j as usize, ((n / splits) * split + i) as usize + offset]] = *e; + } + address += 4; + } + } + } + } + + unsafe fn ita_load_3d( + cpu: &Cpu, + data: &mut Array3, + mut address: u32, + m: u32, + n: u32, + p: u32, + splits: u32, + ) { + for split in 0..splits { + for j in 0..m { + for i in 0..n { + for h in (0..p / splits).step_by(4) { + let word = cpu.binary_load(address, 2); + let elements = std::mem::transmute::(word); + for (offset, e) in elements.iter().enumerate() { + data[[ + j as usize, + i as usize, + ((p / splits) * split + h) as usize + offset, + ]] = *e; + } + address += 4; + } + } + } + } + } + + unsafe fn ita_store_2d( + cpu: &Cpu, + data: &Array2, + address: u32, + m: u32, + n: u32, + splits: u32, + ) { + let mut address_offset = 0; + for split in 0..splits { + for j in 0..m { + for i in (0..n / splits).step_by(4) { + let mut elements = [0u8; 4]; + for offset in 0..elements.len() { + elements[offset] = + data[[j as usize, ((n / splits) * split + i) as usize + offset]] as u8; + } + // let word = std::mem::transmute::<[u8; 4], u32>(elements); + let word = u32::from_ne_bytes(elements); + cpu.binary_store(address + address_offset, word, u32::MAX, 2); + for y in 0..4 { + let _stest = cpu.binary_load(address + address_offset + y, 2); + } + trace!("[ITA] Store OUT to 0x{:x}", address + address_offset); + address_offset += 4; + // if address_offset % 0x100 == 0 { + // address_offset -= 0x0100; + // address_offset += 0x1000; + // } + } + } + } + } + + unsafe fn run_ita( + &self, + cpu: &Cpu, + start_address: u32, + out_address: u32, + _eps_mult_0: u32, + _eps_mult_1: u32, + _right_shift_0: u32, + _right_shift_1: u32, + _add_0: i32, + _add_1: i32, + ) { + // TODO `eps_mult` and `right_shift` are currently hardcoded + // Setup of matrices for query_projection_space_transformation and key_projection_space_transformation + // Sequence of addresses are hardcoded + let start = start_address; + let offset = 64 * 64; + let w4_addr = start + offset * 0; + let w3_addr = start + offset * 1; + let w2_addr = start + offset * 2; + let q_addr = start + offset * 3; + let k_addr = start + offset * 4; + let w1_addr = start + offset * 5; + let b4_addr = start + offset * 6; + let b3_addr = start + offset * 7; + let b2_addr = start + offset * 8; + let b1_addr = start + offset * 9; + + let rqs_mult = u64::to_le_bytes((_eps_mult_1 as u64) << 32 | (_eps_mult_0 as u64)); + let rqs_shift = u64::to_le_bytes((_right_shift_1 as u64) << 32 | (_right_shift_0 as u64)); + let rqs_add = u64::to_le_bytes((_add_1 as u64) << 32 | (_add_0 as u64)).map(|c| c as i8); + + debug!("[ITA] Start Address 0x{:x}, Out Address 0x{:x}", start, out_address); + debug!("[ITA] RQS Mult {:?}", rqs_mult); + debug!("[ITA] RQS Shift {:?}",rqs_shift); + debug!("[ITA] RQS Add {:?}", rqs_add); + + let mut q = Array2::::zeros((64, 64)); + MemPoolITA::ita_load_2d(cpu, &mut q, q_addr, 64, 64, 4); + let mut w_q = Array3::::zeros((1, 64, 64)); + MemPoolITA::ita_load_3d(cpu, &mut w_q, w1_addr, 1, 64, 64, 4); + MemPoolITA::transpose_3d(&mut w_q, 1, 64, 64); + + let mut k = Array2::::zeros((64, 64)); + MemPoolITA::ita_load_2d(cpu, &mut k, k_addr, 64, 64, 4); + + let mut w_k = Array3::::zeros((1, 64, 64)); + MemPoolITA::ita_load_3d(cpu, &mut w_k, w2_addr, 1, 64, 64, 1); + MemPoolITA::transpose_3d(&mut w_k, 1, 64, 64); + + // Setup of matrices for value_projection_space_transformation + let mut b_v = Array3::::zeros((1, 64, 64)); + MemPoolITA::ita_load_3d(cpu, &mut b_v, b3_addr, 1, 64, 64, 4); + MemPoolITA::transpose_3d(&mut b_v, 1, 64, 64); + + let mut v = k.clone(); + let mut w_v = Array3::::zeros((1, 64, 64)); + MemPoolITA::ita_load_3d(cpu, &mut w_v, w3_addr, 1, 64, 64, 1); + MemPoolITA::transpose_3d(&mut w_v, 1, 64, 64); + + let mut v_p = Array3::::zeros((1, 64, 64)); + + // matrices in the query_projection_space_transformation + let mut b_q = Array3::::zeros((1, 64, 64)); + MemPoolITA::ita_load_3d(cpu, &mut b_q, b1_addr, 1, 64, 64, 4); + let mut q_p = Array3::::zeros((1, 64, 64)); + + // matrices in the key_projection_space_transformation + let mut b_k = Array3::::zeros((1, 64, 64)); + MemPoolITA::ita_load_3d(cpu, &mut b_k, b2_addr, 1, 64, 64, 4); + + let mut k_p = Array3::::zeros((1, 64, 64)); + + // matrices in the streaming_partial_softmax + let mut a_requant = Array3::::zeros((1, 64, 64)); + let mut a_partial_softmax = Array2::::zeros((64, 64)); + + // matrices in multi_head_computation + let mut out = Array3::::zeros((1, 64, 64)); + let mut b_o = Array3::::zeros((1, 64, 64)); + MemPoolITA::ita_load_3d(cpu, &mut b_o, b4_addr, 1, 64, 64, 4); + let mut w_o = Array3::::zeros((1, 64, 64)); + MemPoolITA::ita_load_3d(cpu, &mut w_o, w4_addr, 1, 64, 64, 1); + MemPoolITA::transpose_3d(&mut w_o, 1, 64, 64); + + // query_projection_space_transformation + // query_projection_space_transformation(&mut q_p, &mut q, &mut w_q, &mut b_q, 1); + MemPoolITA::projection_space_transformation(&mut q_p, &mut q, &mut w_q, &mut b_q, 1); + // requantization of q_p + let mut q_p_requant = Array3::::zeros((1, 64, 64)); + MemPoolITA::requantization_3d(&mut q_p, &mut q_p_requant, rqs_mult[0], rqs_shift[0], rqs_add[0]); + // debug!("q_p_requant: {}", q_p_requant); + + // key_projection_space_transformation + // key_projection_space_transformation(&mut k_p, &mut k, &mut w_k, &mut b_k, 1); + // debug!("k: {}", k); + // debug!("w_k: {}", w_k); + // debug!("b_k: {}", b_k); + MemPoolITA::projection_space_transformation(&mut k_p, &mut k, &mut w_k, &mut b_k, 1); + // requantization of k_p + let mut k_p_requant = Array3::::zeros((1, 64, 64)); + MemPoolITA::requantization_3d(&mut k_p, &mut k_p_requant, rqs_mult[1], rqs_shift[1], rqs_add[1]); + // debug!("k_p_requant: {}", k_p_requant); + + // query_key_correlation + let mut qk = Array3::::zeros((1, 64, 64)); + MemPoolITA::query_key_correlation(&mut q_p_requant, &mut k_p_requant, &mut qk); + // requantization of qk + MemPoolITA::requantization_3d(&mut qk, &mut a_requant, rqs_mult[2], rqs_shift[2], rqs_add[2]); + // debug!("a_requant: {}", a_requant); + + // streaming_partial_softmax + MemPoolITA::streaming_partial_softmax(&mut a_requant, &mut a_partial_softmax, 64); + + // value_projection_space_transformation + // value_projection_space_transformation(&mut v_p, &mut v, &mut w_v, &mut b_v, 1); + MemPoolITA::projection_space_transformation(&mut v_p, &mut v, &mut w_v, &mut b_v, 1); + // requantization of v_p + let mut v_p_requant = Array3::::zeros((1, 64, 64)); + MemPoolITA::requantization_3d(&mut v_p, &mut v_p_requant, rqs_mult[3], rqs_shift[3], rqs_add[3]); + // debug!("v_p_requant: {}", v_p_requant); + + // single_head_computation + let mut o_softmax = Array3::::zeros((1, 64, 64)); + MemPoolITA::single_head_computation( + &mut a_partial_softmax, + &mut v_p_requant, + &mut o_softmax, + ); + // requantization of o_softmax + let mut o_softmax_requant = Array3::::zeros((1, 64, 64)); + MemPoolITA::requantization_3d(&mut o_softmax, &mut o_softmax_requant, rqs_mult[4], rqs_shift[4], rqs_add[4]); + // debug!("o_softmax_requant: {}", o_softmax_requant); + + // multi_head_computation + MemPoolITA::multi_head_computation(&mut o_softmax_requant, &mut out, &mut w_o, &mut b_o, 1); + // parallel requantization of out + let mut out_requant = Array2::::zeros((64, 64)); + MemPoolITA::parallel_requantize3d(&mut out, &mut out_requant, rqs_mult[5], rqs_shift[5], rqs_add[5]); + // debug!("out_requant: {}", out_requant); + + // for j in 0..out_requant.shape()[1] { + // let row = out_requant.slice(s![j, ..]); + // debug!("out[{},:]:\n{}", j, row); + // } + + // Store the output + MemPoolITA::ita_store_2d(cpu, &out_requant, out_address, 64, 64, 1); + } + + fn requantize_row(element: i32, eps_mult: u8, right_shift: u8, add: i8) -> i8 { + let mut shifted = ((element * (eps_mult as i32)) >> (right_shift as i32)) + (add as i32); + + // Perform rounding half away from zero + if right_shift > 0 && ((element * (eps_mult as i32)) >> ((right_shift-1) as i32)) & 0x1 == 1 { + shifted = shifted.saturating_add(1); + } + if shifted > 127 { + return 127; + } else if shifted < -128 { + return -128; + } else { + return shifted as i8; + } + } + + fn requantization_3d( + m: &mut Array3, + m_requant: &mut Array3, + eps_mult: u8, + right_shift: u8, + add: i8, + ) { + // debug!("===================== 3D Requantization ====================="); + + // Loop over the number of heads + for i in 0..m.shape()[0] { + // Loop over the head dimension + for j in 0..m.shape()[1] { + // print the column of the head matrix + let row = m.slice(s![i, j, ..]); + // Iterate over the row and requantize it + for k in 0..row.len() { + m_requant[[i, j, k]] = + MemPoolITA::requantize_row(row[k], eps_mult, right_shift, add); + } + } + } + } + + fn parallel_requantize3d( + m: &mut Array3, + m_requant: &mut Array2, + eps_mult: u8, + right_shift: u8, + add: i8, + ) { + // debug!("===================== Parallel 3D Requantization ====================="); + m_requant.fill(add); + for i in 0..m.shape()[0] { + for j in 0..m.shape()[1] { + let row = m.slice(s![i, j, ..]); + for k in 0..row.len() { + let mut shifted = ((row[k] * (eps_mult as i32)) >> (right_shift as i32)) + + m_requant[[i * m.shape()[1] + j, k]] as i32; + + // Perform rounding half away from zero + if right_shift > 0 && ((row[k] * (eps_mult as i32)) >> ((right_shift-1) as i32)) & 0x1 == 1 { + shifted = shifted.saturating_add(1); + } + m_requant[[i * m.shape()[1] + j, k]] = + MemPoolITA::requantize_row(shifted, 1, 0, 0); + } + } + } + } + + fn projection_space_transformation( + p: &mut Array3, + m: &mut Array2, + w: &mut Array3, + b: &mut Array3, + bias: u8, + ) { + // debug!("===================== Projection Space Transformation ====================="); + if bias == 1 { + for i in 0..p.shape()[0] { + for j in 0..p.shape()[1] { + for k in 0..p.shape()[2] { + p[[i, j, k]] = b[[i, j, k]] as i32; + for l in 0..m.shape()[1] { + p[[i, j, k]] += m[[j, l]] as i32 * w[[i, l, k]] as i32; + } + } + } + } + } else { + for i in 0..p.shape()[0] { + for j in 0..p.shape()[1] { + for k in 0..p.shape()[2] { + p[[i, j, k]] = 0; + for l in 0..m.shape()[1] { + p[[i, j, k]] += m[[j, l]] as i32 * w[[i, l, k]] as i32; + } + } + } + } + } + + // debug!("projected matrix: {:?}", p); + } + + fn query_key_correlation( + qp_requant: &mut Array3, + kp_requant: &mut Array3, + qk: &mut Array3, + ) { + // debug!("===================== Query Key Correlation ====================="); + + // Loop over the number of heads + for i in 0..qk.shape()[0] { + // Loop over the number of queries + for j in 0..qk.shape()[1] { + // Loop over the number of keys + for k in 0..qk.shape()[2] { + qk[[i, j, k]] = 0; + // Loop over the number of features + for l in 0..qk.shape()[1] { + qk[[i, j, k]] += qp_requant[[i, j, l as usize]] as i32 + * kp_requant[[i, k, l as usize]] as i32; + } + } + } + } + + // debug!("qk: {:?}", qk); + } + + //Compute the approximated softmax function. + fn streaming_partial_softmax( + a_requant: &mut Array3, + a_partial_softmax: &mut Array2, + seq_len: i32, + ) { + // debug!("===================== Streaming Partial SoftMax ====================="); + + // let log2e: f64 = f64::log2(f64::exp(1.0)); + // let b = 8; + // let eps_x = b as f64 / (2.0f64.powi(b) * log2e); + let mut exp_partial_sum = Array1::::zeros(seq_len as usize); + let mut max = Array1::::zeros(64); + let mut current_max = Array1::::zeros(64); + + for i in 0..4 { + let a_requant_slice = a_requant.slice_mut(s![0, .., i * 16..(i + 1) * 16]); + + for n in 0..a_requant_slice.nrows() { + current_max[[n]] = a_requant_slice.row(n).iter().copied().max().unwrap() as i8; + } + + for j in 0..seq_len { + let mut shift_sum; + if i == 0 || current_max[j as usize] > max[[j as usize]] { + if i == 0 { + shift_sum = 0; + } else { + shift_sum = (current_max[j as usize] - max[[j as usize]]) / 32; + // if (((current_max[j as usize] - max[[j as usize]]) / 32) - shift_sum) as f64 + // >= 0.5 + // { + // shift_sum += 1; + // } + let shift_int = + (current_max[j as usize] as i32) - (max[[j as usize]] as i32); + if shift_int % 32 >= 16 { + shift_sum += 1; + } + } + max[j as usize] = current_max[j as usize]; + } else { + shift_sum = 0; + } + + let qb = a_requant + .slice_mut(s![0, .., i * 16..(i + 1) * 16]) + .mapv(|x| x as i32 - max[[j as usize]] as i32); + + let mut qexp = 0; + for k in 0..qb.ncols() { + let mut shift = (-qb[[j as usize, k]]) as i32 / 32; + let shift_int = (-qb[[j as usize, k]]) as i32; + + if shift_int % 32 >= 16 { + shift += 1; + } + + qexp += (2_u32.pow(10) >> shift as i32) as i32; + } + + exp_partial_sum[[j as usize]] = + (exp_partial_sum[[j as usize]] >> shift_sum as i32) + qexp; + } + } + for j in 0..seq_len { + let factor = + ((2.0f64.powi(8) - 1.0) * 2.0f64.powi(10)) as i32 / exp_partial_sum[j as usize]; + for k in 0..seq_len { + let mut shift = (((max[j as usize] as i32) + - (a_requant[[0, j as usize, k as usize]] as i32)) + / 32) as i32; + let shift_int = + (max[j as usize] as i32) - (a_requant[[0, j as usize, k as usize]] as i32); + if shift_int % 32 >= 16 { + shift += 1; + } + a_partial_softmax[[j as usize, k as usize]] = + (factor as i32) / 2.0f64.powi(shift) as i32; + } + } + + // debug!("a_partial_softmax: {}", a_partial_softmax); + } + + fn single_head_computation( + a_partial_softmax: &mut Array2, + vp_requant: &mut Array3, + o_softmax: &mut Array3, + ) { + // debug!("===================== Single Head Computation ====================="); + + // Loop over the number of heads + for i in 0..o_softmax.shape()[0] { + // Loop over the number of queries + for j in 0..o_softmax.shape()[1] { + // Loop over the number of keys + for k in 0..o_softmax.shape()[2] { + o_softmax[[i, j, k]] = 0; + // Loop over the number of features + for l in 0..o_softmax.shape()[1] { + o_softmax[[i, j, k]] += + a_partial_softmax[[j, l]] as i32 * vp_requant[[i, l, k]] as i32; + } + } + } + } + + // debug!("o_softmax: {:?}", o_softmax); + } + + fn multi_head_computation( + o_softmax_requant: &mut Array3, + out: &mut Array3, + w_o: &mut Array3, + b_o: &mut Array3, + bias: u8, + ) { + // debug!("===================== Multi Head Computation ====================="); + + if bias == 1 { + for i in 0..out.shape()[0] { + for j in 0..out.shape()[1] { + for k in 0..out.shape()[2] { + out[[i, j, k]] = b_o[[i, j, k]] as i32; + for l in 0..out.shape()[1] { + out[[i, j, k]] += + o_softmax_requant[[i, j, l]] as i32 * w_o[[i, l, k]] as i32; + } + } + } + } + } else { + for i in 0..out.shape()[0] { + for j in 0..out.shape()[1] { + for k in 0..out.shape()[2] { + out[[i, j, k]] = 0; + for l in 0..out.shape()[1] { + out[[i, j, k]] += + o_softmax_requant[[i, j, l]] as i32 * w_o[[i, l, k]] as i32; + } + } + } + } + } + + // debug!("out: {:?}", out); + } +}