[New Operator] - linalg.lu operator development #1007

Open
PetrelYy opened this issue Apr 19, 2024 · 20 comments
Assignees
Labels: ICT, New Op (Contribute a new operator)

Comments

@PetrelYy (Collaborator)

The development plan can follow these milestones:

  1. Design document writing, xx.xx~xx.xx
  2. Development and self-testing, xx.xx~xx.xx
  3. Open the PR/MR, xx.xx~xx.xx
  4. Review (3 approvals), xx.xx~xx.xx
  5. Maintainer merges
Chuancysun self-assigned this Apr 19, 2024
PetrelYy added the New Op (Contribute a new operator) and ICT labels and removed the ICT label on Apr 23, 2024
@PetrelYy (Collaborator, Author)

@Chuancysun please update the progress.

@Chuancysun (Collaborator)

Currently doing targeted optimization for tall, narrow matrix shapes such as (65536, 30); the code logic of the decomposition kernel has been refactored, and the other shapes can already meet the within-10x performance target.

@Chuancysun (Collaborator)

Addendum: after optimizing the addition operator in my own matrix multiply-accumulate implementation, performance for some matrix shapes still did not meet the target. Profiling showed that the bottleneck is the innermost decomposition kernel, so its code logic has been refactored. The profiler also revealed a large amount of repeated data movement between GDRAM and on-chip memory inside that kernel. The optimized approach keeps all frequently used data on chip and trades extra computation for fewer memory accesses; if on-chip memory is insufficient, the data is stored and processed in segments. A considerable performance improvement is expected.
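As a rough NumPy illustration of this idea (not the actual MLU kernel; the function name trailing_update_segmented and the onchip_budget parameter are made up for the example), the sketch below keeps the factored panel resident in a small simulated on-chip buffer and streams the trailing matrix through it in column segments:

import numpy as np

def trailing_update_segmented(A, k, nb, onchip_budget):
    """Apply the trailing update A22 -= L21 @ U12 in column segments.

    A             : full matrix living in "GDRAM" (here: host memory)
    k, nb         : start column and width of the already-factored panel
    onchip_budget : max number of elements allowed in the simulated on-chip buffer
    """
    m, n = A.shape
    L21 = A[k + nb:, k:k + nb]            # hot data: reused by every segment, kept "on chip"
    rows = m - (k + nb)
    # pick a segment width so that the panel plus one streamed tile fit the budget
    seg = max(1, (onchip_budget - L21.size) // max(1, rows + nb))
    for j in range(k + nb, n, seg):
        j_end = min(j + seg, n)
        U12 = A[k:k + nb, j:j_end]        # stream one tile of U12 in
        A[k + nb:, j:j_end] -= L21 @ U12  # compute and write the updated tile back
    return A

Here the same L21 block is reused by every segment (the "keep high-frequency data on chip" part), while the segment loop is the "store in segments when on-chip memory is not enough" part.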

@Chuancysun (Collaborator)

The PR link is as follows:
#1019

@Chuancysun (Collaborator)

(image: schedule)
The work plan is shown in the figure.

@Chuancysun (Collaborator)

The real-valued, single-batch, non-pivoted LU decomposition is now complete; the design document and code have been posted to the PR, #1019.
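For reference, a minimal NumPy sketch of what the non-pivoted (Doolittle-style) factorization computes is given below. It is an illustration only, not the MLU implementation, and assumes a square matrix whose leading pivots are all nonzero (e.g. diagonally dominant, which is also what the test code later in this thread enforces via make_diagonally_dominant); the name lu_nopivot is made up.

import numpy as np

def lu_nopivot(A):
    """Non-pivoted LU: returns L (unit lower) and U with A == L @ U.

    Only safe when every leading pivot is nonzero, e.g. for
    diagonally dominant matrices.
    """
    A = np.array(A, dtype=np.float64, copy=True)
    n = A.shape[0]
    for k in range(n - 1):
        A[k + 1:, k] /= A[k, k]                                    # column k of L
        A[k + 1:, k + 1:] -= np.outer(A[k + 1:, k], A[k, k + 1:])  # rank-1 trailing update
    L = np.tril(A, -1) + np.eye(n)
    U = np.triu(A)
    return L, U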

@Chuancysun (Collaborator)

The real-valued, multi-batch, non-pivoted LU decomposition is now complete; performance is shown in the figure. For the (4, 5, 65536, 3000) shape, neither the MLU nor MAGMA can allocate enough memory (about 15 GB), and cuSOLVER has no batched LU interface, so MAGMA is used as the comparison baseline here.
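As a rough check (this arithmetic is not from the issue itself), the ~15 GB figure is consistent with the size of a single fp32 copy of the input tensor:

elems = 4 * 5 * 65536 * 3000      # batch dims x matrix dims
print(elems * 4 / 2**30)          # ~14.6 GiB (~15.7 GB) for one fp32 copy of A, before any workspace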

@Chuancysun (Collaborator)

The code will be posted as a PR after some cleanup.

@PetrelYy (Collaborator, Author) commented Jun 12, 2024

Synced with @Chuancysun: the design document and code in the current PR cover only the non-pivoted case; the pivoted design document and code are still under development.

@PetrelYy (Collaborator, Author)

I suggest moving the schedule up by half a week to leave a few more days for review and fixes; otherwise the 7.15 milestone is at serious risk.

@Chuancysun (Collaborator)

(image: schedule)
The updated work plan is shown in the figure. We are currently debugging the performance of the single-batch complex case. The performance optimization of the non-pivoted part is expected to finish on schedule; building on that experience, performance tuning of the pivoted part should go relatively quickly, but the actual situation can only be judged once it is implemented.

@Chuancysun (Collaborator)

Working on functionality and performance tuning for the single-batch complex case; the current focus is optimizing performance for tall, narrow shapes.

@Chuancysun (Collaborator)

Correctness tests for the single-batch and multi-batch complex cases are complete; performance results are shown in the figure. Targeted optimization of tall, narrow matrix shapes is in progress.

@Chuancysun (Collaborator)

The performance optimization for tall, narrow matrices is complete. The non-pivoted LU decomposition is now finished, and the pivoted LU decomposition is under development.

@Chuancysun (Collaborator)

Functionality and performance for the smaller shapes of the pivoted LU decomposition are complete; the blocked implementation for larger shapes is still being developed and debugged.
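For orientation, a blocked (panel) approach to pivoted LU usually factors a narrow panel with partial pivoting, applies the accumulated row swaps across the full row width, and then updates the trailing submatrix with a matrix multiply. The NumPy sketch below is an illustration only, not the MLU kernel; the function name, block size nb, and square-matrix assumption are all choices made for the example.

import numpy as np

def lu_blocked_pivot(A, nb=64):
    """Blocked right-looking LU with partial pivoting (NumPy sketch).

    Returns the packed factors LU and a permutation `perm` such that
    A[perm] == L @ U, with L unit lower triangular and U upper triangular.
    """
    A = np.array(A, dtype=np.float64, copy=True)
    n = A.shape[0]
    perm = np.arange(n)
    for k in range(0, n, nb):
        kb = min(nb, n - k)
        # 1) factor the tall panel A[k:, k:k+kb] with row pivoting
        for j in range(k, k + kb):
            p = j + int(np.argmax(np.abs(A[j:, j])))
            if p != j:
                A[[j, p], :] = A[[p, j], :]   # swap full rows (also fixes up L and A12/A22)
                perm[[j, p]] = perm[[p, j]]
            A[j + 1:, j] /= A[j, j]
            # rank-1 update restricted to the panel columns
            A[j + 1:, j + 1:k + kb] -= np.outer(A[j + 1:, j], A[j, j + 1:k + kb])
        if k + kb < n:
            # 2) U12 = L11^{-1} @ A12 (solve with the unit-lower panel block)
            L11 = np.tril(A[k:k + kb, k:k + kb], -1) + np.eye(kb)
            A[k:k + kb, k + kb:] = np.linalg.solve(L11, A[k:k + kb, k + kb:])
            # 3) trailing update A22 -= L21 @ U12
            A[k + kb:, k + kb:] -= A[k + kb:, k:k + kb] @ A[k:k + kb, k + kb:]
    return A, perm

A quick sanity check: after unpacking L = np.tril(LU, -1) + np.eye(n) and U = np.triu(LU), np.allclose(A[perm], L @ U) should hold for a random square A.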

@Chuancysun (Collaborator)

(screenshot: NRAM/SRAM max_size test kernel)
As shown in the figure, the kernel tests the max_size of NRAM and SRAM, but allocations start reporting out-of-memory errors at roughly 512 KB for NRAM and 2048 KB for SRAM, whereas according to the documentation the sizes should be 640 KB and 4096 KB. The error is shown below:
(screenshot: error message)

@PetrelYy (Collaborator, Author) commented Jul 16, 2024

nram float uint8_t test[512*1024]; /// this line is wrong: it declares two types

Writing code like the above is not recommended, because NRAM_SIZE differs between boards; on the 590, the NRAM space available for operator development is not 512 KB.

@ArtIntAI (Collaborator) commented Sep 4, 2024

Please post the test code and the json.

@Chuancysun (Collaborator)

mannul_shape_1.json
The test case can be referenced from the file above.

@Chuancysun (Collaborator)

compute.py is as follows:
import torch
import numpy as np
from nonmlu_ops.base import *
import logging
import copy
import os

@registerTensorList("sgetrf2")
class sgetrf2TensorList(TensorList):
    pass
# def castDataNode(self):
# '''
# cast input data to onchip data by input_dtype and input_onchip_dtype.
# '''
# # compute baseline output
# # Qcast Force Cast FFT
# # x_fp -----------> x_int --------------> x_fp -------> y_fp
# for input_tensor in self.input_tensors_:
# input_datanode = input_tensor.getDataNode()
# input_onchip_datanode = input_tensor.onchip_datanode_
# if input_onchip_datanode.dtype_.isQuantType():
# bitnum = input_onchip_datanode.dtype_.getDataBits()
# if_scale = input_tensor.if_scale_
# if_offset = input_tensor.if_offset_
# if input_datanode.dtype_.isFloatPoint():
# # has nan or has inf, return
# if np.isnan(input_datanode.data_).any() or np.isinf(input_datanode.data_).any():
# return
# # Qcast
# position, scale, offset = quantize_utils.compute_quant_param(input_datanode.data_, bitnum, if_scale, if_offset)
# input_onchip_datanode.setData(quantize_utils.float2fix(input_datanode.data_, bitnum, position, scale, offset))
# input_onchip_datanode.setQuantParam(position, scale, offset)
# # Force cast
# input_datanode.setData(quantize_utils.fix2float(bitnum, input_onchip_datanode.data_, position, scale, offset))
# elif input_datanode.dtype_.isComplex():
# real_data, imag_data = input_datanode.getComplexData()
# complex_data = np.concatenate((real_data, imag_data), axis=-1)
# # has nan or has inf, return
# if np.isnan(complex_data).any() or np.isinf(complex_data).any():
# return
# # Qcast
# position, scale, offset = quantize_utils.compute_quant_param(complex_data, bitnum, if_scale, if_offset)
# real_quant_data = quantize_utils.float2fix(real_data, bitnum, position, scale, offset)
# imag_quant_data = quantize_utils.float2fix(imag_data, bitnum, position, scale, offset)
# input_onchip_datanode.setQuantParam(position, scale, offset)
# # Force cast
# real_dequant_data = quantize_utils.fix2float(bitnum, real_quant_data, position, scale, offset)
# imag_dequant_data = quantize_utils.fix2float(bitnum, imag_quant_data, position, scale, offset)
# input_datanode.setComplexData(real_dequant_data, imag_dequant_data)

def print_matrix(A):
    if A.ndim == 3:
        batch = A.shape[0]
        size = A.shape[1]
        for i in range(batch):
            for j in range(size):
                for k in range(size):
                    print("{:.3}".format(A[i][j][k]), end=" ")
                print("\n")
            print("\n")
    elif A.ndim == 2:
        size = A.shape[0]
        for i in range(size):
            for j in range(size):
                print("{:.3}".format(A[i][j]), end=" ")
            print("\n")
    elif A.ndim == 1:
        size = A.shape[0]
        for i in range(size):
            print("{}".format(A[i]), end=" ")
        print("\n")

def set_complex_data(data_node, complex_tensor):
    cpu_array = complex_tensor.cpu().numpy()
    cpu_real = np.real(cpu_array)
    cpu_imag = np.imag(cpu_array)
    data_node.setComplexData(cpu_real, cpu_imag)

def set_values_below_threshold(input_tensor, threshold=1e-3, new_value=1e-6):
    # get the data type
    dtype = input_tensor.dtype
    if dtype == torch.float32 or dtype == torch.complex64:
        new_value_pos = torch.tensor(new_value, dtype=torch.float32, device=input_tensor.device)
        new_value_neg = torch.tensor(-new_value, dtype=torch.float32, device=input_tensor.device)
    elif dtype == torch.float64 or dtype == torch.complex128:
        new_value_pos = torch.tensor(new_value, dtype=torch.float64, device=input_tensor.device)
        new_value_neg = torch.tensor(-new_value, dtype=torch.float64, device=input_tensor.device)
    else:
        raise ValueError("Unsupported tensor dtype")

    # for complex tensors, handle the real and imaginary parts separately
    if torch.is_complex(input_tensor):
        real_part = input_tensor.real
        imag_part = input_tensor.imag

        real_part[(real_part.abs() < threshold) & (real_part >= 0)] = new_value_pos
        real_part[(real_part.abs() < threshold) & (real_part < 0)] = new_value_neg
        imag_part[(imag_part.abs() < threshold) & (imag_part >= 0)] = new_value_pos
        imag_part[(imag_part.abs() < threshold) & (imag_part < 0)] = new_value_neg

        input_tensor = torch.complex(real_part, imag_part)
    else:
        # for non-complex tensors
        input_tensor[(input_tensor.abs() < threshold) & (input_tensor >= 0)] = new_value_pos
        input_tensor[(input_tensor.abs() < threshold) & (input_tensor < 0)] = new_value_neg

    return input_tensor

def set_diag_imag_one(input_tensor):
    if input_tensor.dim() == 2:
        diag_indices = torch.arange(input_tensor.size(0), device=input_tensor.device)
        input_tensor[diag_indices, diag_indices] += 1j - input_tensor[diag_indices, diag_indices].imag * 1j

    elif input_tensor.dim() == 3:
        batch_size, n, _ = input_tensor.size()
        for i in range(batch_size):
            diag_indices = torch.arange(n, device=input_tensor.device)
            input_tensor[i, diag_indices, diag_indices] += 1j - input_tensor[i, diag_indices, diag_indices].imag * 1j

def matrix_multiply(A, B):
    # get the matrix dimensions
    rows_A, cols_A = A.shape
    rows_B, cols_B = B.shape

    # check that the dimensions match
    if cols_A != rows_B:
        raise ValueError("The number of columns of A must equal the number of rows of B")

    # create the result matrix C, initialized to zero
    C = torch.zeros((rows_A, cols_B))

    # triple loop implementing the matrix multiplication
    for i in range(rows_A):
        for j in range(cols_B):
            for k in range(cols_A):
                C[i][j] += A[i][k] * B[k][j]

    return C

def extract_LU(LU, pivots):
    if LU.dim() == 2:
        # single matrix
        m, n = LU.size()
        if torch.is_complex(LU):
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device) * (1 + 0j)
        else:
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device)
        U = torch.triu(LU)
        if m < n:
            L = L[:, :m]  # trim to m x m
        elif m > n:
            U = U[:n, :]

    elif LU.dim() == 3:
        # batch of matrices
        batch_size, m, n = LU.size()
        if torch.is_complex(LU):
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device).expand(batch_size, -1, -1) * (1 + 0j)
        else:
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device).expand(batch_size, -1, -1)
        U = torch.triu(LU)
        if m < n:
            L = L[:, :, :m]  # trim to batch x m x m
        elif m > n:
            U = U[:, :n, :]

    elif LU.dim() == 4:
        # flatten the leading batch dimensions
        batch_size, depth, m, n = LU.size()
        LU = LU.view(batch_size * depth, m, n)
        # batch of matrices
        batch_size, m, n = LU.size()
        if torch.is_complex(LU):
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device).expand(batch_size, -1, -1) * (1 + 0j)
        else:
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device).expand(batch_size, -1, -1)
        U = torch.triu(LU)
        if m < n:
            L = L[:, :, :m]  # trim to batch x m x m
        elif m > n:
            U = U[:, :n, :]

    else:
        raise ValueError("Unsupported number of dimensions for LU tensor")

    return L, U

def make_diagonally_dominant(input_data):
    if input_data.dim() == 2:
        # single matrix
        m, n = input_data.size()
        min_mn = min(m, n)
        for i in range(min_mn):
            if torch.is_complex(input_data):
                # complex matrix
                real_part = input_data[i, i].real
                imag_part = input_data[i, i].imag
                input_data[i, i] = torch.complex(real_part + n, imag_part)
            else:
                # real matrix
                input_data[i, i] += n
    elif input_data.dim() == 3:
        # batch of matrices
        batchCount, m, n = input_data.size()
        min_mn = min(m, n)
        for s in range(batchCount):
            for i in range(min_mn):
                if torch.is_complex(input_data):
                    # complex matrix
                    real_part = input_data[s, i, i].real
                    imag_part = input_data[s, i, i].imag
                    input_data[s, i, i] = torch.complex(real_part + n, imag_part)
                else:
                    # real matrix
                    input_data[s, i, i] += n
    elif input_data.dim() == 4:
        # flatten the leading batch dimensions
        batch_size, depth, m, n = input_data.size()
        input_data = input_data.view(batch_size * depth, m, n)
        # batch of matrices
        batchCount, m, n = input_data.size()
        min_mn = min(m, n)
        for s in range(batchCount):
            for i in range(min_mn):
                if torch.is_complex(input_data):
                    # complex matrix
                    real_part = input_data[s, i, i].real
                    imag_part = input_data[s, i, i].imag
                    input_data[s, i, i] = torch.complex(real_part + n, imag_part)
                else:
                    # real matrix
                    input_data[s, i, i] += n
    return input_data

# Function to swap two rows of a matrix
def swap_rows(matrix, row1, row2):
    # print("swap row1 row2", row1, row2)
    # print("row1 ")
    # size = matrix[row1-1, :].shape[0]
    # for i in range(size):
    #     print("{:.3}".format(matrix[row1-1, i]), end=" ")
    # print("row2 ")
    # size = matrix[row2-1, :].shape[0]
    # for i in range(size):
    #     print("{:.3}".format(matrix[row2-1, i]), end=" ")

    matrix[[row1-1, row2-1], :] = matrix[[row2-1, row1-1], :]

# Function to apply row swaps to matrix A using ipiv
def apply_row_swaps(A, ipiv):
    if A.dim() == 2:
        batch_size = 1
        m, n = A.size()
    elif A.dim() == 3:
        batch_size, m, n = A.size()
    elif A.dim() == 4:
        batch_size, depth, m, n = A.size()
        batch_size = batch_size * depth
    else:
        raise ValueError("Unsupported number of dimensions for A tensor")
    m = min(m, n)
    if batch_size > 1:
        for b in range(batch_size):
            for i in range(m - 1, -1, -1):  # Iterate backwards
                if ipiv[b, i] - 1 != i:
                    swap_rows(A[b], i + 1, ipiv[b, i])
                    # ipiv[b, ipiv[b, i] - 1], ipiv[b, i] = ipiv[b, i], ipiv[b, ipiv[b, i] - 1]
    else:
        for i in range(m - 1, -1, -1):  # Iterate backwards
            if ipiv[i] - 1 != i:
                # print("i", i)
                # print("ipiv[i] ", ipiv[i])
                swap_rows(A, i + 1, ipiv[i])
                # ipiv[ipiv[i] - 1], ipiv[i] = ipiv[i], ipiv[ipiv[i] - 1]

@registerOp("sgetrf2")
class sgetrf2Op(OpTest):
    def __init__(self, tensorlist, params):
        super().__init__(tensorlist, params)
        self.mode_ = self.params_.get("mode")

    compute_cout = 0

def compute(self):
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
    gpu_count = torch.cuda.device_count()
    print("gpu_count:")
    print(torch.cuda.device_count())
    print("os visible:",os.environ.get('CUDA_VISIBLE_DEVICES'))
    cuda_count = sgetrf2Op.compute_cout % gpu_count
    print('now gpu:',cuda_count)
    sgetrf2Op.compute_cout += 1

    result_mul = True

    input_tensor = self.tensor_list_.getInputTensor(0)
    output_tensor = self.tensor_list_.getOutputTensor(0)
    input_is_complex = input_tensor.getDataType().isComplex()
    mode = self.mode_
    

    if not input_is_complex:
        # input_data = torch.tensor(input_tensor.getData()).cuda(cuda_count)
        # input_data_fp64 = input_data.type(torch.float64).cuda(cuda_count)
        # upper_triangle = torch.triu(input_data, diagonal=1)
        # L_matrix = input_data - upper_triangle
        
        # del input_data
        # del upper_triangle
        # torch.cuda.empty_cache()
        
        # batch = 1
        # size = L_matrix.size(1)
        # if L_matrix.dim() == 2:
        #     U_matrix = L_matrix.transpose(0, 1)
        #     A = torch.mm(L_matrix, U_matrix)
        #     del L_matrix
        #     torch.cuda.empty_cache()
        #     eye = torch.eye(size, dtype=torch.float32).cuda(cuda_count)
        #     A = A + eye
        # elif L_matrix.dim() == 3:
        #     U_matrix = L_matrix.transpose(1, 2)
        #     A = torch.bmm(L_matrix, U_matrix)
        #     batch = L_matrix.size(0)
        #     del L_matrix
        #     torch.cuda.empty_cache()
        #     eye = torch.eye(size, dtype=torch.float32).expand(batch, -1, -1).cuda(cuda_count)
        #     A = A + eye
        # else:
        #     exit()
        
        # del eye
        # del U_matrix
        # torch.cuda.empty_cache()
        
        # input matrix A
        flag = (mode == 1)
        print("pivot ",flag)
        
        input_data = torch.tensor(input_tensor.getData()).cuda(cuda_count)
        if flag == False:
            input_data = make_diagonally_dominant(input_data)
        input_data_fp64 = input_data.type(torch.float64).cuda(cuda_count)
        input_tensor.setData(input_data.cpu().numpy())
        
        # print("input:")
        # print_matrix(input_data.cpu().numpy())

        # input_tensor.setData(input_data)
        torch.backends.cuda.preferred_linalg_library(backend='cusolver')
        LU, pivots = torch.linalg.lu_factor(input_data, pivot=flag)
        # print("LU ",LU)
        
        # batch = 1
        # size = L_matrix.size(1)

        L_matrix, U_matrix = extract_LU(LU, pivots)
        print("L U size",L_matrix.size(),U_matrix.size())
        # print("L",L_matrix)
        # print("U",U_matrix)
        if result_mul or mode == 1:
            if L_matrix.dim() == 2:
                result = torch.mm(L_matrix, U_matrix)
            else:
                result = torch.bmm(L_matrix, U_matrix)
            
        else:
            result = LU
            
        # print("ipiv fp32")
        # print_matrix(pivots.cpu().numpy())
        if mode == 1:
            # Apply pivots to the result to restore the original matrix
            apply_row_swaps(result, pivots)
            
        # print("result fp32 LU:")
        # print_matrix(result.cpu().numpy())
        
        output_result = result.cpu().numpy()
        output_tensor.setData(output_result)
        
        del LU
        del pivots
        del result
        del L_matrix
        del U_matrix
        del output_result
        del input_data
        torch.cuda.empty_cache()
        

        # upper_triangle_fp64 = torch.triu(input_data_fp64, diagonal=1)
        # L_matrix_fp64 = input_data_fp64 - upper_triangle_fp64
        # del upper_triangle_fp64
        # del input_data_fp64
        # torch.cuda.empty_cache()
        
        # if L_matrix_fp64.dim() == 2:
        #     U_matrix_fp64 = L_matrix_fp64.transpose(0, 1)
        #     A_fp64 = torch.mm(L_matrix_fp64, U_matrix_fp64)
        #     del L_matrix_fp64
        #     del U_matrix_fp64
        #     A_fp64 = A_fp64 + torch.eye(size, dtype=torch.float64).cuda(cuda_count)
        # elif L_matrix_fp64.dim() == 3:
        #     U_matrix_fp64 = L_matrix_fp64.transpose(1, 2)
        #     A_fp64 = torch.bmm(L_matrix_fp64, U_matrix_fp64)
        #     del L_matrix_fp64
        #     del U_matrix_fp64
        #     A_fp64 = A_fp64 + torch.eye(size, dtype=torch.float64).expand(batch, -1, -1).cuda(cuda_count)
        
        # torch.cuda.empty_cache()
        A_fp64 = input_data_fp64.double()
        
        

        result_LU_fp64, pivots = torch.linalg.lu_factor(A_fp64, pivot=flag)
        
        result_L_fp64, result_U_fp64 = extract_LU(result_LU_fp64, pivots)
        
        # del A_fp64
        torch.cuda.empty_cache()


        base_node = DataNode("double")

        if result_mul or mode == 1:
            if result_L_fp64.dim() == 2:
                result = torch.matmul(result_L_fp64, result_U_fp64)
            else:
                result = torch.bmm(result_L_fp64, result_U_fp64)

            
        else:
            result = result_LU_fp64
                
        if mode == 1:
            # Apply pivots to the result to restore the original matrix
            apply_row_swaps(result, pivots)

            # print("orign fp64 result:")
            # print_matrix(result_L_fp64.cpu().numpy())
            
            # result = result_L_fp64
        # print("result fp64 LU:")
        # print_matrix(result.cpu().numpy())
        # print("ipiv fp64")
        # print_matrix(pivots.cpu().numpy())
        output_result = result.cpu().numpy()
        base_node.setData(output_result)
        del result
        del pivots
        del result_L_fp64
        del result_U_fp64
        del result_LU_fp64
        del input_data_fp64
        del A_fp64
        del output_result
        torch.cuda.empty_cache()


        # print_matrix(output_result_fp64)

        
        half_dynamic_threshold = 1e-3
        float_dynamic_threshold = 1e-5
        eva = diff_utils.Evaluator(base_node, output_tensor.getDataNode(), half_dynamic_threshold, float_dynamic_threshold)
        diff1 = eva.computeDiff1()
        diff2 = eva.computeDiff2()
        diff_3_2 = eva.computeDiff3_2(10.0)
        print("diff1: ", diff1)
        print("diff2: ", diff2)
        print("diff_3_2: ", diff_3_2)
        output_tensor.setDiff(diff1, diff2, -1, diff_3_2, -1)

    else:
        torch.backends.cuda.preferred_linalg_library(backend='cusolver')
        flag = (mode == 1)
        print("pivot ",flag)
        input_real_data, input_imag_data = input_tensor.getComplexData()
        # combine into a complex tensor
        input_complex_data = torch.complex(torch.from_numpy(input_real_data), torch.from_numpy(input_imag_data))
        # upper_triangle_real = np.tril(input_real_data)
        # upper_triangle_imag = np.tril(input_imag_data, k=-1)
        # complex_numpy_array = upper_triangle_real + 1j * upper_triangle_imag
        # del upper_triangle_real
        # del upper_triangle_imag
        
        input_data = torch.tensor(input_complex_data, dtype=torch.complex64).cuda(cuda_count)
        if flag== False:
            input_data = make_diagonally_dominant(input_data)
        set_complex_data(input_tensor, input_data)
        # del complex_numpy_array
        # print("origin input:")
        # print_matrix(input_data.cpu().numpy())
        # input_real_data = np.expand_dims(input_real_data, -1)
        # print("real data:",input_real_data)
        # input_imag_data = np.expand_dims(input_imag_data, -1)
        # print("imag data:",input_imag_data)
        # input_complex_data = np.concatenate((input_real_data, input_imag_data), axis=-1)

        # input_data = torch.tensor(input_complex_data).cuda(cuda_count)

        # print("complex data:",input_data)


        input_data_complex128 = input_data.type(torch.complex128).cuda(cuda_count)
        # upper_triangle_complex64 = torch.triu(input_data, diagonal=1)
        # L_matrix_complex64 = input_data - upper_triangle_complex64

        # del input_data
        # del upper_triangle_complex64
        # # free the GPU memory
        # torch.cuda.empty_cache()

        # batch = 1
        # size = L_matrix_complex64.size(1)
        # print("Tensor shape:", L_matrix_complex64.shape)
        # if L_matrix_complex64.dim() == 2:
        #     U_matrix_complex64 = L_matrix_complex64.transpose(0, 1).conj()
        #     A_complex64 = torch.mm(L_matrix_complex64, U_matrix_complex64)
        #     del L_matrix_complex64
        #     torch.cuda.empty_cache()
        #     eye_complex64 = torch.eye(size, dtype=torch.complex64).cuda(cuda_count)
        #     A_complex64 = A_complex64 + eye_complex64
        # elif L_matrix_complex64.dim() == 3:
        #     U_matrix_complex64 = L_matrix_complex64.transpose(1, 2).conj()
        #     A_complex64 = torch.bmm(L_matrix_complex64, U_matrix_complex64)
        #     batch = L_matrix_complex64.size(0)
        #     del L_matrix_complex64
        #     torch.cuda.empty_cache()
        #     eye_complex64 = torch.eye(size, dtype=torch.complex64).expand(batch, -1, -1).cuda(cuda_count)
        #     A_complex64 = A_complex64 + eye_complex64
        # else:
        #     exit()

        # del eye_complex64
        # del U_matrix_complex64
        # torch.cuda.empty_cache()

        # set_complex_data(input_tensor, A_complex64)

        # print("input A:")
        # print_matrix(A_complex64.cpu().numpy())

        result_LU_complex64, pivots = torch.linalg.lu_factor(input_data,pivot=flag)
        
        result_L_complex64, result_U_complex64 = extract_LU(result_LU_complex64, pivots)
        # del A_complex64
        torch.cuda.empty_cache()
        
        if result_mul or mode == 1 :
            if result_L_complex64.dim() == 2:
                result = torch.mm(result_L_complex64, result_U_complex64)
            else:
                result = torch.bmm(result_L_complex64, result_U_complex64)
        else:
            result = result_LU_complex64
            
        if mode == 1:
            # Apply pivots to the result to restore the original matrix
            apply_row_swaps(result, pivots)
            
        # set_values_below_threshold(result)
        # set_diag_imag_one(result)
        set_complex_data(output_tensor, result)
        # print("result complex64 result:")
        # print_matrix(result.cpu().numpy())
        # print("ipiv complex64")
        # print_matrix(pivots.cpu().numpy())
        del pivots
        del input_real_data
        del input_imag_data
        del result_L_complex64
        del result_U_complex64
        del result_LU_complex64
        del result
        torch.cuda.empty_cache()

            # output_result = result.cpu().numpy()
            # output_tensor.setData(output_result)


        # print("result1:")
        # print_matrix(result1.cpu().numpy())

        # output_result_complex64 = result_L_complex64.cpu().numpy()

        # output_real = np.real(output_result_complex64)

        # output_imag = np.imag(output_result_complex64)

        # output_tensor.setComplexData(output_real, output_imag)

        
        # del result_L_complex64
        # torch.cuda.empty_cache()

        # print_matrix(input_data_complex128)

        # upper_triangle_complex128 = torch.triu(input_data_complex128, diagonal=1)
        # L_matrix_complex128 = input_data_complex128 - upper_triangle_complex128
        # L_matrix_complex128 = input_data_complex128

        # del input_data_complex128
        # del upper_triangle_complex128
        # torch.cuda.empty_cache()

        
        # if L_matrix_complex128.dim() == 2:
 
        #     A_complex128 = torch.mm(L_matrix_complex128, L_matrix_complex128.transpose(0, 1).conj()) + torch.eye(size, dtype=torch.complex128).cuda(cuda_count)
        #     del L_matrix_complex128
        #     torch.cuda.empty_cache()  
        # elif L_matrix_complex128.dim() == 3:

            
        #     A_complex128 = torch.bmm(L_matrix_complex128, L_matrix_complex128.transpose(1, 2).conj()) 
        #     del L_matrix_complex128
        #     torch.cuda.empty_cache()  
        #     A_complex128 = A_complex128 + torch.eye(size, dtype=torch.complex128).expand(batch, -1, -1).cuda(cuda_count)
        
            
        # print_matrix(A_complex64.cpu().numpy())
        
        # print_matrix(result_L_complex64.cpu().numpy())

        result_LU_complex128, pivots = torch.linalg.lu_factor(input_data_complex128,pivot=flag)
        
        result_L_complex128, result_U_complex128 = extract_LU(result_LU_complex128, pivots)
        # del A_complex128
        torch.cuda.empty_cache()

        base_node = DataNode("complex128")

        if result_mul:
            if result_L_complex128.dim() == 2:
                result = torch.mm(result_L_complex128, result_U_complex128)
            else:
                result = torch.bmm(result_L_complex128, result_U_complex128)

        else:
           result = result_LU_complex128

        if mode == 1:
            # Apply pivots to the result to restore the original matrix
            apply_row_swaps(result, pivots)
        
        # set_diag_imag_one(result)
        # print("result 128:")
        # print_matrix(result.cpu().numpy())
        # print("ipiv complex128")
        # print_matrix(pivots.cpu().numpy())
        set_complex_data(base_node, result)
        del result_LU_complex128
        del result_L_complex128
        del result_U_complex128
        del pivots
        del result
        del input_data
        del input_data_complex128
        torch.cuda.empty_cache()


        # print("result:")
        # print_matrix(result_L_complex128.cpu().numpy())

        # output_result_complex128 = result_L_complex128.cpu().numpy()

        # print("A_complex128 result:")
        # print_matrix(A_complex128)

        # print("complex128 result:")
        # print_matrix(output_result_complex128)

        

        # print_matrix(output_result_complex128)

        

        # output_real_fp64 = np.real(output_result_complex128)

        # output_imag_fp64 = np.imag(output_result_complex128)

        

        # base_node.setComplexData(output_real_fp64, output_imag_fp64)

        half_dynamic_threshold = 1e-3
        float_dynamic_threshold = 1e-5
        eva = diff_utils.Evaluator(base_node, output_tensor.getDataNode(), half_dynamic_threshold, float_dynamic_threshold)
        diff_3_2 = eva.computeDiff3_2(1.0)
        diff1 = eva.computeDiff1()
        diff2 = eva.computeDiff2()
        print("diff1: ", diff1)
        print("diff2: ", diff2)
        print("diff_3_2: ", diff_3_2)
        output_tensor.setDiff(diff1, diff2, -1, diff_3_2, -1)

    # print("还存在的变量:", locals())
    local_vars = list(locals().keys())
    # 删除所有局部变量
    for var in local_vars:
        del locals()[var]
    torch.cuda.empty_cache()

    # if output_is_complex:
    #     # set base node
    #     base_node = DataNode("complex128")
    #     base_node.setComplexData(output_real_fp64, output_imag_fp64)
    #     # set output tensor
    #     # if pytorch do not support half fft, convert output from float32 to float16
    #     if output_is_half:
    #         output_real = output_real.astype("float16")
    #         output_imag = output_imag.astype("float16")
    #     has_inf = np.isinf(output_real).any() or np.isinf(output_imag).any() or \
    #               np.isinf(output_real_fp64).any() or np.isinf(output_imag_fp64).any()
    #     has_nan = np.isnan(output_real).any() or np.isnan(output_imag).any() or \
    #               np.isnan(output_real_fp64).any() or np.isnan(output_imag_fp64).any()
    #     output_tensor.setComplexData(output_real, output_imag)
    # else:
    #     # set base node
    #     base_node = DataNode("double")
    #     base_node.setData(output_result_fp64)
    #     # set output tensor
    #     # if pytorch do not support half fft, convert output from float32 to float16
    #     if output_is_half:
    #         output_result = output_result.astype("float16")
    #     has_inf = np.isinf(output_result_fp64).any() or np.isinf(output_result).any()
    #     has_nan = np.isnan(output_result_fp64).any() or np.isnan(output_result).any()
    #     output_tensor.setData(output_result)

@registerProtoWriter("sgetrf2")
class sgetrf2ProtoWriter(MluOpProtoWriter):
    def dumpOpParam2Node(self):
        sgetrf2_param_node = self.proto_node_.sgetrf2_param
        sgetrf2_param_node.mode = self.op_params_.get("mode")
        # sgetrf2_param_node.n.extend(self.op_params_.get("n"))
        # sgetrf2_param_node.direction = self.op_params_.get("direction")
        # sgetrf2_param_node.scale_factor = self.op_params_.get("scale_factor")
