diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cb5263e --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.idea* +*__pycache__ + +static/output + diff --git a/readme.md b/README.md similarity index 100% rename from readme.md rename to README.md diff --git a/app.py b/app.py index abdc848..88a23d6 100644 --- a/app.py +++ b/app.py @@ -1,68 +1,41 @@ -from flask import Flask, render_template, request, jsonify, send_from_directory, abort,send_file -from flask_cors import CORS import time -from demo import * -from preprocessing_img import * +from flask import Flask, request +from flask_cors import CORS +from py_profiler.profiler_controller import profiler_blueprint + +from enhance_service import * app = Flask(__name__) +app.register_blueprint(profiler_blueprint) CORS(app) -client_count=0 #to seperate -@app.route('/') -def index(): - return "hello" +folder_out = "static/output/" +os.makedirs(folder_out, exist_ok=True) + +enhance_service = EnhanceService( + 'Deblurring/pretrained_models/model_deblurring.pth', + use_cpu=True +) -@app.route('/enhance',methods=['POST', 'GET']) -def enhance0(): - if request.method == 'POST': - global client_count - urls=dict(request.get_json())['urls'] - # - folder_in="static/clients/input/"+format(client_count,'04d')+"/" - folder_out="static/clients/output/"+format(client_count,'04d')+"/" - # - try: - os.makedirs(folder_in) - except: - pass - try: - os.makedirs(folder_out) - except: - pass - clear_folder(folder_in) - clear_folder(folder_out) - begin=time.time() - im_pre=Im_preprocess() - im_pre.close() - im_pre.from_urls_to_array(urls) - im_pre.write_img(folder_in) - process(folder_in,folder_out) +@app.route('/') +def index(): + return "hello" - img_files=os.listdir(folder_out) - img_files.sort() - return_list=[] - for i in im_pre.correct_idx: - return_list.append({ - 'before':urls[im_pre.correct_idx[i]], - 'after':'/get_image/'+format(client_count,'04d')+'/'+img_files[i] - }) - client_count+=1 - if client_count%499==0: - client_count=0 +@app.route('/enhance', methods=['POST']) +def enhance(): + urls = dict(request.get_json())['urls'] + # + begin = time.time() + output_result = enhance_service.process(urls, folder_out) + print(f'Output: {output_result}') return { - 'time':time.time()-begin, - 'result':return_list + 'time': time.time() - begin, + 'result': output_result } - return {'message':'You need to use POST metod'} - -@app.route('/get_image//') -def get_image(client_count,img_name): - print('./static/clients/output'+format(client_count,'04d')+"/"+img_name) - return send_file('./static/clients/output/'+format(client_count,'04d')+"/"+img_name,attachment_filename=img_name,as_attachment=False) if __name__ == '__main__': - app.run() \ No newline at end of file + app.run() diff --git a/demo.py b/demo.py deleted file mode 100644 index e322025..0000000 --- a/demo.py +++ /dev/null @@ -1,175 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.transforms.functional as TF -from PIL import Image -import os -from runpy import run_path -from skimage import img_as_ubyte -from collections import OrderedDict -from natsort import natsorted -from glob import glob -import cv2 -import argparse -#from enlighten_inference import EnlightenOnnxModel -import time -import numpy as np -import os -from IAGCWD import White_Balancer -from CEIQ import CEIQ -import threading -# from MEON_demo import MEON_eval -# torch.cuda.set_per_process_memory_fraction(0.8, device=None) -# torch.cuda.empty_cache() -# os.environ["CUDA_VISIBLE_DEVICES"]="" - - - -def save_img(filepath, img): - # cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) - cv2.imwrite(filepath, img) - -def load_checkpoint(model, weights): - checkpoint = torch.load(weights) - # checkpoint = torch.load(weights, map_location='cpu') - try: - model.load_state_dict(checkpoint["state_dict"]) - except: - state_dict = checkpoint["state_dict"] - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = k[7:] # remove `module.` - new_state_dict[name] = v - model.load_state_dict(new_state_dict) - -task = 'Deblurring' -load_file = run_path(os.path.join(task, "MPRNet.py")) -model = load_file['MPRNet']() -model.cuda() - -weights = os.path.join(task, "pretrained_models", "model_"+task.lower()+".pth") -load_checkpoint(model, weights) -model.eval() - -white_balancer = White_Balancer() -# Load CEIQ model - -CEIQ_model = CEIQ() -#os.makedirs(out_dir, exist_ok=True) - -# Remove all files in these two folder -import os, shutil -def empty_dir(dir): - for filename in os.listdir(dir): - file_path = os.path.join(dir, filename) - try: - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - print('Failed to delete %s. Reason: %s' % (file_path, e)) -#empty_dir(out_dir) -print("Flushing successfully") - -def process_one_image(file_,img_multiple_of,enhanced_out_dir,mutex): - img = Image.open(file_).convert('RGB') - input_ = TF.to_tensor(img).unsqueeze(0).cuda() - # input_ = TF.to_tensor(img).unsqueeze(0) - - - # Pad the input if not_multiple_of 8 - h,w = input_.shape[2], input_.shape[3] - H,W = ((h+img_multiple_of)//img_multiple_of)*img_multiple_of, ((w+img_multiple_of)//img_multiple_of)*img_multiple_of - padh = H-h if h%img_multiple_of!=0 else 0 - padw = W-w if w%img_multiple_of!=0 else 0 - input_ = F.pad(input_, (0,padw,0,padh), 'reflect') - - mutex.acquire() - with torch.no_grad(): - restored = model(input_) - mutex.release() - restored = restored[0] - restored = torch.clamp(restored, 0, 1) - - # Unpad the output - restored = restored[:,:,:h,:w] - restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy() - - restored = img_as_ubyte(restored[0]) - # processed = enlighten_model.predict(cv2.cvtColor(restored, cv2.COLOR_RGB2BGR)) - # processed = img_as_ubyte(processed) - - ### White balancing ### - # Extract intensity component of the image - deblurred_img = cv2.cvtColor(restored, cv2.COLOR_RGB2BGR) - YCrCb = cv2.cvtColor(deblurred_img, cv2.COLOR_BGR2YCrCb) - Y = YCrCb[:,:,0] - # Determine whether image is bright or dimmed - threshold = 0.3 - exp_in = 112 # Expected global average intensity - M,N = deblurred_img.shape[:2] - mean_in = np.sum(Y/(M*N)) - t = (mean_in - exp_in)/ exp_in - - # Process image for gamma correction - img_output = None - if t < -threshold: # Dimmed Image - print (file_ + ": Dimmed") - result = white_balancer.process_dimmed(Y) - YCrCb[:,:,0] = result - img_output = cv2.cvtColor(YCrCb,cv2.COLOR_YCrCb2BGR) - elif t > threshold: - print (file_ + ": Bright Image") # Bright Image - result = white_balancer.process_bright(Y) - YCrCb[:,:,0] = result - img_output = cv2.cvtColor(YCrCb,cv2.COLOR_YCrCb2BGR) - else: - img_output = deblurred_img - restored = cv2.cvtColor(restored, cv2.COLOR_RGB2BGR) - # Compute CEIQ score and decide whether the image was significantly enhanced or not - # org_score = CEIQ_model.predict(np.expand_dims(restored, axis=0), 1)[0] - # imp_score = CEIQ_model.predict(np.expand_dims(img_output, axis=0), 1)[0] - scores = CEIQ_model.predict([restored, img_output], 1) - # if file_.split('/')[-1] == '20210420050550-de27_wm.jpg': - # save_img(os.path.join('tmp_folder', f+'[restored].png'), restored) - # save_img(os.path.join('tmp_folder', f+'[img_output].png'), img_output) - # print(f"Scoreeeee: {(scores[0], scores[1])}") - if scores[0] > scores[1]: - img_output = restored - - f = os.path.splitext(os.path.split(file_)[-1])[0] - #save_img(os.path.join(out_dir, f+'.png'), restored) - save_img(os.path.join(enhanced_out_dir, f+'.png'), img_output) - # cv2.imwrite(os.path.join(enhanced_out_dir, f+'.png'), img_output) - -def process(inp_dir, enhanced_out_dir): - files = natsorted(glob(os.path.join(inp_dir, '*.jpg')) - + glob(os.path.join(inp_dir, '*.JPG')) - + glob(os.path.join(inp_dir, '*.jpeg')) - + glob(os.path.join(inp_dir, '*.JPEG')) - + glob(os.path.join(inp_dir, '*.png')) - + glob(os.path.join(inp_dir, '*.PNG'))) - - if len(files) == 0: - raise Exception(f"No files found at {inp_dir}") - - # Load corresponding model architecture and weights - - img_multiple_of = 8 - #enlighten_model = EnlightenOnnxModel() - - start_time = time.monotonic() - print('Number of files: ', len(files)) - # print(files) - threads=[] - mutex = threading.Lock() - for file_ in files: - t=threading.Thread(target=process_one_image,args=[file_,img_multiple_of,enhanced_out_dir,mutex]) - t.start() - threads.append(t) - for thread in threads: - thread.join() - - print(f"Processing Time: {time.monotonic() - start_time}") - print(f"Enhanced images saved at {enhanced_out_dir}") diff --git a/enhance_service.py b/enhance_service.py new file mode 100644 index 0000000..912b6a5 --- /dev/null +++ b/enhance_service.py @@ -0,0 +1,158 @@ +import os +# from enlighten_inference import EnlightenOnnxModel +import uuid +from concurrent import futures +from concurrent.futures import ThreadPoolExecutor +from typing import List + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from py_profiler import profiler +from skimage import img_as_ubyte + +from CEIQ import CEIQ +from Deblurring.MPRNet import MPRNet +# from MEON_demo import MEON_eval +# torch.cuda.set_per_process_memory_fraction(0.8, device=None) +# torch.cuda.empty_cache() +# os.environ["CUDA_VISIBLE_DEVICES"]="" +from image_downloader import HybirdImageDownloader +from utils.model_utils import load_checkpoint, save_img +from white_balancer import WhiteBalancer + + +class EnhanceService: + + def __init__(self, deblur_model_path, use_cpu: bool = True): + self.use_cpu = use_cpu + # Executor to run enhance process concurrently + self.executor = ThreadPoolExecutor(max_workers=8) + # A downloader to download image using a thread pool with 16 threads + self.image_downloader = HybirdImageDownloader(16) + + # task = 'Deblurring' + # load_file = run_path(os.path.join(task, "MPRNet.py")) + # model = load_file['MPRNet']() # Type: MPRNet + self.deblur_model = MPRNet() + if self.use_cpu is not True: + self.deblur_model.cuda() + + load_checkpoint(self.deblur_model, deblur_model_path, use_cpu=use_cpu) + self.deblur_model.eval() + + self.white_balancer = WhiteBalancer() + self.ceiq_scoring_model = CEIQ() + + print(f'white_balancer: {type(self.white_balancer)}') + print(f'CEIQ_model: {type(self.ceiq_scoring_model)}') + print(f'Model: {type(self.deblur_model)}') + + print("Init successfully") + + @profiler() + def process(self, image_urls: List[str], enhanced_out_dir): + image_dict = self.image_downloader.bulk_download_as_image(image_urls) + if len(image_urls) == 0: + raise Exception(f"No image urls found at {image_urls}") + print('Number of files: ', len(image_dict)) + + future_to_checks = { + self.executor.submit(self._enhance_image, image, 8, enhanced_out_dir): url + for url, image in image_dict.items() + } + + result_dict = {} + # Now it comes to the result of each check + # The try-except-else clause is omitted here + for future in futures.as_completed(future_to_checks): + url = future_to_checks[future] + output_path = future.result() + result_dict[url] = output_path + return result_dict + + @profiler() + def _enhance_image(self, image, factor, out_dir) -> str: + restored = self._deblur_image(image, factor) + # processed = enlighten_model.predict(cv2.cvtColor(restored, cv2.COLOR_RGB2BGR)) + # processed = img_as_ubyte(processed) + img_output = self._process_white_balancing(restored) + + restored = cv2.cvtColor(restored, cv2.COLOR_RGB2BGR) + + scores = self._calc_score([restored, img_output]) + if scores[0] > scores[1]: + img_output = restored + + output_path = os.path.join(out_dir, f'{uuid.uuid1()}.png') + save_img(output_path, img_output) + + return output_path + + @profiler() + def _deblur_image(self, img, factor: int = 8): + img = img.convert('RGB') + input_image_as_tensor = TF.to_tensor(img).unsqueeze(0) if self.use_cpu else TF.to_tensor(img).unsqueeze( + 0).cuda() + + # Pad the input if not_multiple_of 8 + h, w = input_image_as_tensor.shape[2], input_image_as_tensor.shape[3] + H, W = ((h + factor) // factor) * factor, ( + (w + factor) // factor) * factor + padh = H - h if h % factor != 0 else 0 + padw = W - w if w % factor != 0 else 0 + input_image_as_tensor = F.pad(input_image_as_tensor, (0, padw, 0, padh), 'reflect') + + with torch.no_grad(): + restored = self.deblur_model(input_image_as_tensor) + restored = restored[0] + restored = torch.clamp(restored, 0, 1) + + # Unpad the output + restored = restored[:, :, :h, :w] + restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy() + restored = img_as_ubyte(restored[0]) + + return restored + + @profiler() + def _process_white_balancing(self, input_image, threshold: float = 0.3): + ### White balancing ### + # Extract intensity component of the image + deblurred_img = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR) + YCrCb = cv2.cvtColor(deblurred_img, cv2.COLOR_BGR2YCrCb) + Y = YCrCb[:, :, 0] + # Determine whether image is bright or dimmed + + exp_in = 112 # Expected global average intensity + M, N = deblurred_img.shape[:2] + mean_in = np.sum(Y / (M * N)) + t = (mean_in - exp_in) / exp_in + + # Process image for gamma correction + output_image = None + if t < -threshold: # Dimmed Image + print('Dimmed') + result = self.white_balancer.process_dimmed(Y) + YCrCb[:, :, 0] = result + output_image = cv2.cvtColor(YCrCb, cv2.COLOR_YCrCb2BGR) + elif t > threshold: + print('Bright Image') # Bright Image + result = self.white_balancer.process_bright(Y) + YCrCb[:, :, 0] = result + output_image = cv2.cvtColor(YCrCb, cv2.COLOR_YCrCb2BGR) + else: + output_image = deblurred_img + + return output_image + + @profiler() + def _calc_score(self, images): + # Compute CEIQ score and decide whether the image was significantly enhanced or not + # org_score = CEIQ_model.predict(np.expand_dims(restored, axis=0), 1)[0] + # imp_score = CEIQ_model.predict(np.expand_dims(img_output, axis=0), 1)[0] + scores = self.ceiq_scoring_model.predict(images, option=1) + print(f"Scores: {scores[0]} -> {scores[1]}: Improved: {((scores[1] - scores[0]) * 100 / scores[0])} %") + return scores diff --git a/image_downloader.py b/image_downloader.py new file mode 100644 index 0000000..f51cc9e --- /dev/null +++ b/image_downloader.py @@ -0,0 +1,65 @@ +import io +from concurrent import futures +from concurrent.futures import ThreadPoolExecutor +from typing import Dict +from urllib.request import Request, urlopen + +from PIL import Image +from py_profiler import profiler + + +class ImageDownloader: + + def __init__(self, num_threads: int = 8): + self.num_threads = num_threads + self.executor = ThreadPoolExecutor(max_workers=num_threads) + + def bulk_download_as_image(self, image_urls: list) -> dict: + + r: Dict[str, bytearray] = self.bulk_download(image_urls) + + result_dict = {} + for k, v in r.items(): + result_dict[k] = Image.open(io.BytesIO(v)) + + return result_dict + + @profiler(f'{__qualname__}.bulk_download') + def bulk_download(self, image_urls: list) -> Dict[str, bytearray]: + future_to_checks = { + self.executor.submit(self.download_url, url): url + for url in image_urls + } + + result_dict = {} + # Now it comes to the result of each check + # The try-except-else clause is omitted here + for future in futures.as_completed(future_to_checks): + url = future_to_checks[future] + image_bytearray = future.result() + + result_dict[url] = image_bytearray + + return result_dict + + def download_url(self, path: str) -> bytearray: + pass + + +class HybirdImageDownloader(ImageDownloader): + + @profiler("download_url") + def download_url(self, path: str) -> bytearray: + if path.startswith('http:') or path.startswith('https:'): + headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" + } + import ssl + gcontext = ssl.SSLContext() + req = Request(path, headers=headers) + res = urlopen(req, context=gcontext) + raw = bytearray(res.read()) + else: + with open(path, 'rb') as reader: + raw = bytearray(reader.read()) + return raw diff --git a/img_preprocessor.py b/img_preprocessor.py new file mode 100644 index 0000000..7ddd541 --- /dev/null +++ b/img_preprocessor.py @@ -0,0 +1,45 @@ +import threading + +from skimage import io + +from utils.model_utils import bypass_ssl_verify + +bypass_ssl_verify() +mutex = threading.Lock() + + +class ImgPreprocessor: + correct_idx = [] # index in urls that have been downloaded normally + img_list = [] + + def from_urls_to_array(self, urls): + threads = [] + for i, url in enumerate(urls): + t = threading.Thread(target=self.load_one_image, args=[url, i]) + t.start() + threads.append(t) + + for thread in threads: + thread.join() + + def load_one_image(self, url, i): + try: + print(f'Load image: {url}') + img = io.imread(url) # load concurrently + + mutex.acquire() # mutex lock + self.img_list.append(img) + self.correct_idx.append(i) + mutex.release() # mutex release + except Exception as err: + print(err) + pass + + def write_img(self, path): + for idx, img in enumerate(self.img_list): + print(f'Save image: {idx}: {len(img)}') + io.imsave(path + format(idx, '04d') + ".png", img) + + def close(self): + self.correct_idx = [] + self.img_list = [] diff --git a/preprocessing_img.py b/preprocessing_img.py deleted file mode 100644 index 34201e4..0000000 --- a/preprocessing_img.py +++ /dev/null @@ -1,39 +0,0 @@ -from skimage import io -import threading -import os -def clear_folder(folder_name): - a=os.listdir(folder_name) - for file_ in a: - os.remove(os.path.join(folder_name,file_)) - print(str(len(a))+" files in '"+folder_name+"' folder deleted") -mutex = threading.Lock() -class Im_preprocess: - correct_idx=[] #index in urls that have been downloaded normally - img_list=[] - def from_urls_to_array(self,urls): - threads=[] - for i in range(len(urls)): - t=threading.Thread(target=self.load_one_image,args=[urls,i]) - t.start() - threads.append(t) - - for thread in threads: - thread.join() - - def load_one_image(self,urls,i): - try: - img=io.imread(urls[i])#load concurrently - - mutex.acquire() #mutex lock - self.img_list.append(img) - self.correct_idx.append(i) - mutex.release() #mutex release - except: - pass - - def write_img(self, path): - for i in range(len(self.img_list)): - io.imsave( path+format(i,'04d')+".png",self.img_list[i]) - def close(self): - self.correct_idx=[] - self.img_list=[] diff --git a/requirements.txt b/requirements.txt index 8c2a79e..51b8790 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,3 +26,4 @@ torch==1.9.0 torchvision==0.10.0 typing-extensions==3.10.0.0 Werkzeug==2.0.1 +py_profiler \ No newline at end of file diff --git a/static/clients/input/0000/0000.png b/static/clients/input/0000/0000.png deleted file mode 100644 index c1bc4de..0000000 Binary files a/static/clients/input/0000/0000.png and /dev/null differ diff --git a/static/clients/input/0000/0001.png b/static/clients/input/0000/0001.png deleted file mode 100644 index 7fa5268..0000000 Binary files a/static/clients/input/0000/0001.png and /dev/null differ diff --git a/static/clients/input/0001/0000.png b/static/clients/input/0001/0000.png deleted file mode 100644 index 7fa5268..0000000 Binary files a/static/clients/input/0001/0000.png and /dev/null differ diff --git a/static/clients/input/0001/0001.png b/static/clients/input/0001/0001.png deleted file mode 100644 index c1bc4de..0000000 Binary files a/static/clients/input/0001/0001.png and /dev/null differ diff --git a/static/clients/output/0001/0000.png b/static/clients/output/0001/0000.png deleted file mode 100644 index 8c5e6e5..0000000 Binary files a/static/clients/output/0001/0000.png and /dev/null differ diff --git a/static/clients/output/0001/0001.png b/static/clients/output/0001/0001.png deleted file mode 100644 index afe5a0d..0000000 Binary files a/static/clients/output/0001/0001.png and /dev/null differ diff --git a/utils/model_utils.py b/utils/model_utils.py new file mode 100644 index 0000000..4cc0c8f --- /dev/null +++ b/utils/model_utils.py @@ -0,0 +1,49 @@ +# from enlighten_inference import EnlightenOnnxModel +import os +import shutil +import ssl +from collections import OrderedDict + +import torch + + +def bypass_ssl_verify(): + if (not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(ssl, '_create_unverified_context', None)): + ssl._create_default_https_context = ssl._create_unverified_context + +def save_img(filepath, img): + import cv2 + # cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + cv2.imwrite(filepath, img) + + +def clear_folder(folder_name): + a = os.listdir(folder_name) + for file_ in a: + os.remove(os.path.join(folder_name, file_)) + print(str(len(a)) + " files in '" + folder_name + "' folder deleted") + + +def empty_dir(dir_path: str): + for filename in os.listdir(dir_path): + file_path = os.path.join(dir_path, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print('Failed to delete %s. Reason: %s' % (file_path, e)) + + +def load_checkpoint(model, weights: str, use_cpu: bool = True): + checkpoint = torch.load(weights, map_location='cpu') if use_cpu else torch.load(weights) + try: + model.load_state_dict(checkpoint["state_dict"]) + except: + state_dict = checkpoint["state_dict"] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + model.load_state_dict(new_state_dict) diff --git a/IAGCWD.py b/white_balancer.py similarity index 59% rename from IAGCWD.py rename to white_balancer.py index 1a7b74f..882cb20 100644 --- a/IAGCWD.py +++ b/white_balancer.py @@ -1,16 +1,13 @@ -import cv2 -from glob import glob -import argparse import numpy as np -from matplotlib import pyplot as plt -from scipy.linalg import fractional_matrix_power -from natsort import natsorted -import os, shutil +from py_profiler import profiler -class White_Balancer: + +class WhiteBalancer: + + @profiler() def image_agcwd(self, img, a=0.25, truncated_cdf=False): - h,w = img.shape[:2] - hist,bins = np.histogram(img.flatten(),256,[0,256]) + h, w = img.shape[:2] + hist, bins = np.histogram(img.flatten(), 256, [0, 256]) cdf = hist.cumsum() cdf_normalized = cdf / cdf.max() prob_normalized = hist / hist.sum() @@ -20,30 +17,32 @@ def image_agcwd(self, img, a=0.25, truncated_cdf=False): intensity_min = unique_intensity.min() prob_min = prob_normalized.min() prob_max = prob_normalized.max() - + pn_temp = (prob_normalized - prob_min) / (prob_max - prob_min) - pn_temp[pn_temp>0] = prob_max * (pn_temp[pn_temp>0]**a) - pn_temp[pn_temp<0] = prob_max * (-((-pn_temp[pn_temp<0])**a)) - prob_normalized_wd = pn_temp / pn_temp.sum() # normalize to [0,1] + pn_temp[pn_temp > 0] = prob_max * (pn_temp[pn_temp > 0] ** a) + pn_temp[pn_temp < 0] = prob_max * (-((-pn_temp[pn_temp < 0]) ** a)) + prob_normalized_wd = pn_temp / pn_temp.sum() # normalize to [0,1] cdf_prob_normalized_wd = prob_normalized_wd.cumsum() - - if truncated_cdf: - inverse_cdf = np.maximum(0.5,1 - cdf_prob_normalized_wd) + + if truncated_cdf: + inverse_cdf = np.maximum(0.5, 1 - cdf_prob_normalized_wd) else: inverse_cdf = 1 - cdf_prob_normalized_wd - + img_new = img.copy() for i in unique_intensity: - img_new[img==i] = np.round(255 * (i / 255)**inverse_cdf[i]) - + img_new[img == i] = np.round(255 * (i / 255) ** inverse_cdf[i]) + return img_new + @profiler() def process_bright(self, img): img_negative = 255 - img agcwd = self.image_agcwd(img_negative, a=0.25, truncated_cdf=False) reversed = 255 - agcwd return reversed + @profiler() def process_dimmed(self, img): agcwd = self.image_agcwd(img, a=0.75, truncated_cdf=True) return agcwd