-
Notifications
You must be signed in to change notification settings - Fork 32
/
matrix.rs
130 lines (118 loc) · 3.51 KB
/
matrix.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
use std::ops::{Index, IndexMut};
/// Common interface shared by 2-D matrix types.
///
/// Elements are addressed with a `(usize, usize)` pair through the
/// [`Index`]/[`IndexMut`] supertraits, whose output is [`MatrixTrait::Scalar`]
/// (NOTE(review): presumably `(row, col)` — confirm against implementations).
///
/// NOTE(review): the per-method summaries below are derived from the names and
/// signatures declared here; implementations live elsewhere, so confirm details
/// such as element order, bounds handling and inclusive/exclusive ranges there.
pub trait MatrixTrait:
    Sized + Index<(usize, usize), Output = Self::Scalar> + IndexMut<(usize, usize)>
{
    /// Element type stored in the matrix.
    type Scalar;

    // ── Raw storage access ───────────────────────────────────────────

    /// Raw read-only pointer to `Self::Scalar` storage
    /// (presumably the buffer that `as_slice` exposes — confirm).
    fn ptr(&self) -> *const Self::Scalar;

    /// Raw mutable pointer to `Self::Scalar` storage.
    fn mut_ptr(&mut self) -> *mut Self::Scalar;

    /// Borrows every element as one flat slice.
    fn as_slice(&self) -> &[Self::Scalar];

    /// Mutably borrows every element as one flat slice.
    fn as_mut_slice(&mut self) -> &mut [Self::Scalar];

    // ── Layout ───────────────────────────────────────────────────────

    /// Returns a new matrix with its "shape" changed.
    /// NOTE(review): what "shape" denotes (e.g. storage order) is not visible here.
    fn change_shape(&self) -> Self;

    /// In-place counterpart of [`MatrixTrait::change_shape`].
    fn change_shape_mut(&mut self);

    // ── Rendering ────────────────────────────────────────────────────

    /// Renders the matrix into a `String` (presumably a human-readable layout).
    fn spread(&self) -> String;

    // ── Extraction ───────────────────────────────────────────────────

    /// Copies column `index` into a `Vec`.
    fn col(&self, index: usize) -> Vec<Self::Scalar>;

    /// Copies row `index` into a `Vec`.
    fn row(&self, index: usize) -> Vec<Self::Scalar>;

    /// Copies the diagonal entries into a `Vec`.
    fn diag(&self) -> Vec<Self::Scalar>;

    /// Nested-`Vec` copy of the matrix (presumably one inner `Vec` per row).
    fn to_vec(&self) -> Vec<Vec<Self::Scalar>>;

    /// Matrix derived from `self`'s diagonal
    /// (presumably off-diagonal entries are dropped — confirm).
    fn to_diag(&self) -> Self;

    /// Copies the sub-matrix delimited by the `start` and `end` positions.
    /// NOTE(review): whether `end` is inclusive is implementation-defined.
    fn submat(&self, start: (usize, usize), end: (usize, usize)) -> Self;

    // ── Transformation ───────────────────────────────────────────────

    /// Transposed copy of `self`.
    fn transpose(&self) -> Self;

    /// Short alias that forwards to [`MatrixTrait::transpose`].
    fn t(&self) -> Self {
        Self::transpose(self)
    }

    // ── In-place replacement ─────────────────────────────────────────

    /// Overwrites column `idx` with the values in `v`.
    fn subs_col(&mut self, idx: usize, v: &[Self::Scalar]);

    /// Overwrites row `idx` with the values in `v`.
    fn subs_row(&mut self, idx: usize, v: &[Self::Scalar]);

    /// Overwrites the region delimited by `start` and `end` with `m`.
    fn subs_mat(&mut self, start: (usize, usize), end: (usize, usize), m: &Self);

    // ── Construction ─────────────────────────────────────────────────

    /// Builds a matrix of dimensions `size` (presumably `(rows, cols)`) by
    /// calling `f` with each index pair and converting each result into
    /// `Self::Scalar` via [`Into`].
    fn from_index<F, G>(f: F, size: (usize, usize)) -> Self
    where
        F: Fn(usize, usize) -> G + Copy,
        G: Into<Self::Scalar>;
}
// ┌─────────────────────────────────────────────────────────┐
// For Linear Algebra
// └─────────────────────────────────────────────────────────┘
/// Linear-algebra operations whose results are expressed in the matrix type `M`
/// and its element type `M::Scalar`.
///
/// `Self` is the operand being factorized or solved against. The free `solve`
/// function, for example, requires `M: MatrixTrait + LinearAlgebra<M>`.
///
/// NOTE(review): the per-method summaries below are derived from method names
/// and signatures only. The implementations live elsewhere, so confirm details
/// (pivoting, tolerances, failure modes) there.
pub trait LinearAlgebra<M>
where
    M: MatrixTrait,
{
    // ── Factorizations ───────────────────────────────────────────────

    /// LU factorization with permutations, packaged as [`PQLU`].
    fn lu(&self) -> PQLU<M>;

    /// WAZ factorization packaged as [`WAZD`]. `d_form` is a [`Form`] that
    /// presumably selects how the `d` factor is represented. May return
    /// `None`; when it does is up to the implementation.
    fn waz(&self, d_form: Form) -> Option<WAZD<M>>;

    /// QR factorization, packaged as [`QR`].
    fn qr(&self) -> QR<M>;

    /// Singular value decomposition, packaged as [`SVD`].
    fn svd(&self) -> SVD<M>;

    /// Cholesky factor. `uplo` presumably chooses the upper or lower
    /// triangular form. Compiled only when the `O3` feature is enabled.
    #[cfg(feature = "O3")]
    fn cholesky(&self, uplo: UPLO) -> M;

    /// Splits `self` into four sub-matrices.
    /// NOTE(review): block order and split point are implementation-defined.
    fn block(&self) -> (M, M, M, M);

    // ── Substitution and solving ─────────────────────────────────────

    /// Back substitution against the right-hand side `b`
    /// (presumably `self` is upper triangular — confirm).
    fn back_subs(&self, b: &[M::Scalar]) -> Vec<M::Scalar>;

    /// Forward substitution against the right-hand side `b`
    /// (presumably `self` is lower triangular — confirm).
    fn forward_subs(&self, b: &[M::Scalar]) -> Vec<M::Scalar>;

    /// Solves for a vector with right-hand side `b`, using strategy `sk`.
    fn solve(&self, b: &[M::Scalar], sk: SolveKind) -> Vec<M::Scalar>;

    /// Matrix right-hand-side counterpart of `solve`; `m` is the right-hand side.
    fn solve_mat(&self, m: &M, sk: SolveKind) -> M;

    // ── Derived matrices and scalars ─────────────────────────────────

    /// Reduced row echelon form of `self`.
    fn rref(&self) -> M;

    /// Determinant of `self`.
    fn det(&self) -> M::Scalar;

    /// Inverse of `self`.
    fn inv(&self) -> M;

    /// Pseudo-inverse of `self` (presumably Moore–Penrose — confirm).
    fn pseudo_inv(&self) -> M;

    // ── Predicates ───────────────────────────────────────────────────

    /// Whether `self` is symmetric.
    fn is_symmetric(&self) -> bool;
}
/// Free-function form of [`LinearAlgebra::solve_mat`].
///
/// Returns exactly `A.solve_mat(b, sk)`. `A` is the receiver, `b` is the
/// right-hand side given as a matrix, and `sk` is the [`SolveKind`] strategy,
/// which is forwarded unchanged.
// `A` keeps the capital letter customary for a matrix operand; that is why
// the snake_case lint is silenced for this function.
#[allow(non_snake_case)]
pub fn solve<M>(A: &M, b: &M, sk: SolveKind) -> M
where
    M: MatrixTrait + LinearAlgebra<M>,
{
    <M as LinearAlgebra<M>>::solve_mat(A, b, sk)
}
/// Complete-pivoting LU decomposition, as returned by [`LinearAlgebra::lu`].
///
/// `p` and `q` are permutations (presumably of rows and columns
/// respectively). `l` and `u` are the lower- and upper-triangular factors.
///
/// # Examples
/// ```rust
/// use peroxide::fuga::*;
///
/// let a = ml_matrix("1 2;3 4");
/// let pqlu = a.lu();
/// let (p, q, l, u) = pqlu.extract();
/// // p, q are permutations
/// // l, u are matrices
/// l.print(); // lower triangular
/// u.print(); // upper triangular
/// ```
#[derive(Clone, Debug)]
pub struct PQLU<M: MatrixTrait> {
    /// First permutation, as indices (presumably rows).
    pub p: Vec<usize>,
    /// Second permutation, as indices (presumably columns).
    pub q: Vec<usize>,
    /// Lower-triangular factor.
    pub l: M,
    /// Upper-triangular factor.
    pub u: M,
}
/// Factors produced by [`LinearAlgebra::waz`] (which wraps them in `Option`).
///
/// NOTE(review): the field roles below are inferred from the type's name.
/// Confirm them against the `waz` implementation.
#[derive(Clone, Debug)]
pub struct WAZD<M: MatrixTrait> {
    /// The `W` factor.
    pub w: M,
    /// The `Z` factor.
    pub z: M,
    /// The `D` factor. Its form is presumably governed by the `d_form`
    /// argument given to `waz`.
    pub d: M,
}
/// Factors produced by [`LinearAlgebra::qr`].
#[derive(Clone, Debug)]
pub struct QR<M: MatrixTrait> {
    /// The `Q` factor (presumably orthogonal — confirm).
    pub q: M,
    /// The `R` factor (presumably upper triangular — confirm).
    pub r: M,
}
/// Representation choice passed as `d_form` to [`LinearAlgebra::waz`].
///
/// NOTE(review): this presumably selects how the `d` factor of [`WAZD`] is
/// returned. The variant meanings below are inferred from their names; confirm.
///
/// Derives `PartialEq`/`Eq`, matching the sibling [`UPLO`] enum, so callers
/// can compare values with `==` instead of pattern matching.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Form {
    /// `d` kept as a general diagonal matrix.
    Diagonal,
    /// `d` normalized to the identity.
    Identity,
}
/// Strategy selector taken by [`LinearAlgebra::solve`],
/// [`LinearAlgebra::solve_mat`] and the free [`solve`] function.
///
/// Derives `PartialEq`/`Eq`, matching the sibling [`UPLO`] enum, so a strategy
/// can be compared with `==` instead of pattern matching.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum SolveKind {
    /// Presumably solves through the LU decomposition ([`LinearAlgebra::lu`]).
    LU,
    /// Presumably solves through the WAZ decomposition ([`LinearAlgebra::waz`]).
    WAZ,
}
/// Upper/lower triangle selector, taken by `LinearAlgebra::cholesky`
/// (a method compiled only with the `O3` feature).
// NOTE(review): rustc's `non_camel_case_types` lint does not appear to fire on
// all-caps names such as `UPLO`, so this allowance is likely redundant; it is
// kept to avoid any change in lint configuration.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UPLO {
    /// The upper triangle.
    Upper,
    /// The lower triangle.
    Lower,
}
/// Borrowing accessors for the two [`QR`] factors.
impl<M> QR<M>
where
    M: MatrixTrait,
{
    /// Borrows the `q` factor without copying it.
    pub fn q(&self) -> &M {
        let Self { q, .. } = self;
        q
    }

    /// Borrows the `r` factor without copying it.
    pub fn r(&self) -> &M {
        let Self { r, .. } = self;
        r
    }
}
/// Result of [`LinearAlgebra::svd`].
///
/// NOTE(review): the field roles follow the conventional `A = U Σ Vᵀ` naming
/// and are not established in this file. Confirm them against the `svd`
/// implementation.
#[derive(Clone, Debug)]
pub struct SVD<M: MatrixTrait> {
    /// Singular values, stored as `f64` regardless of `M::Scalar`.
    pub s: Vec<f64>,
    /// Left singular vectors (`U`).
    pub u: M,
    /// Transposed right singular vectors (`Vᵀ`).
    pub vt: M,
}