# Real-ESRGAN / main.py
# Uploaded by qqc1989 ("Upload 8 files"), commit dfc6b8a (verified).
import cv2
import numpy as np
import axengine as ort
import time
import argparse
def get_model(model_path: str) -> "tuple[ort.InferenceSession, int, int]":
    """Load an axengine model and return it with its input width and height.

    The name, shape and dtype of every model input and output are printed to
    stdout.

    Args:
        model_path: Path to the compiled model file (e.g. ``.axmodel``).

    Returns:
        A ``(session, width, height)`` tuple. ``width`` is dimension 2 and
        ``height`` is dimension 1 of the first input's shape.
    """
    model = ort.InferenceSession(model_path)
    for tensor in model.get_inputs():
        print(tensor.name, tensor.shape, tensor.dtype)
    for tensor in model.get_outputs():
        print(tensor.name, tensor.shape, tensor.dtype)
    # NOTE(review): reading height from shape[1] and width from shape[2]
    # assumes an NHWC input layout - confirm against the exported model.
    input_shape = model.get_inputs()[0].shape
    width = input_shape[2]
    height = input_shape[1]
    return model, width, height
def preprocess_image(image, width=64, height=64, *, letterbox=True):
    """Resize an image to the model input size and add a batch axis.

    By default the image is scaled to fit inside ``width`` x ``height`` while
    keeping its aspect ratio, then centred on a black canvas (letterboxing).
    With ``letterbox=False`` it is instead stretched to exactly
    ``width`` x ``height``.

    Args:
        image: Image array such as the one returned by ``cv2.imread``.
            Single-channel ``(H, W)`` / ``(H, W, 1)`` and 4-channel BGRA
            images are converted to 3-channel BGR first.
        width: Target width in pixels.
        height: Target height in pixels.
        letterbox: Keep the aspect ratio and pad with black (default)
            instead of stretching.

    Returns:
        ``uint8`` array of shape ``(1, height, width, 3)``.

    Raises:
        ValueError: If ``image`` is ``None`` or empty, or has a channel
            count other than 1, 3 or 4.
    """
    if image is None or image.size == 0:
        raise ValueError("preprocess_image: image is missing or empty")

    # The output canvas is always 3-channel, so normalise the input first;
    # otherwise assigning a 1- or 4-channel image into it fails.
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    elif image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(
            f"preprocess_image: unsupported image shape {image.shape}")

    if not letterbox:
        # Plain stretch to the target size (aspect ratio not preserved).
        stretched = cv2.resize(image, (width, height))
        return np.expand_dims(stretched.astype(np.uint8, copy=False), axis=0)

    h, w = image.shape[:2]
    scale_ratio = min(width / w, height / h)
    # round() rather than int(): truncating a product like 63.9999 would
    # leave a 1-pixel black strip. Clamp so neither side collapses to 0
    # (cv2.resize rejects empty sizes) or overshoots the canvas.
    new_w = min(width, max(1, round(w * scale_ratio)))
    new_h = min(height, max(1, round(h * scale_ratio)))
    resized_img = cv2.resize(image, (new_w, new_h))

    # Black canvas of the target size; centre the resized image on it.
    letterboxed_img = np.zeros((height, width, 3), dtype=np.uint8)
    top = (height - new_h) // 2
    left = (width - new_w) // 2
    letterboxed_img[top:top + new_h, left:left + new_w] = resized_img

    # Add the batch dimension.
    return np.expand_dims(letterboxed_img, axis=0)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process an image with a given model.")
    parser.add_argument('--input', type=str, required=True, help='Path to the input image.')
    parser.add_argument('--output', type=str, required=True, help='Path to save the output image.')
    parser.add_argument('--model', type=str, required=True, help='Path to the model file (.axmodel).')
    args = parser.parse_args()

    model, width, height = get_model(args.model)

    img = cv2.imread(args.input)
    # cv2.imread reports failure (missing file, unsupported format) by
    # returning None rather than raising.
    if img is None:
        raise SystemExit(f"Error: could not read input image: {args.input}")
    print("Original Image Shape:", img.shape)
    img = preprocess_image(img, width, height)
    print("Preprocessed Image Shape:", img.shape)

    # Feed the tensor under the model's own first-input name instead of a
    # hard-coded "input.1".
    input_name = model.get_inputs()[0].name

    # Time only the inference call. perf_counter is monotonic and
    # high-resolution, unlike time.time().
    start_time = time.perf_counter()
    output = model.run(None, {input_name: img})[0]
    elapsed_ms = (time.perf_counter() - start_time) * 1000  # seconds -> ms
    print(f"Inference Time: {elapsed_ms:.2f} ms")
    print("Output Shape:", output.shape)

    # Clamp to [0, 1], scale to 8-bit and round. A bare astype() would
    # truncate, shifting every pixel down by up to one level.
    # NOTE(review): [0] drops the batch axis and the result goes straight to
    # cv2.imwrite, which assumes an NHWC (batch, H, W, C) output - confirm.
    output_img = np.round(np.clip(output, 0.0, 1.0) * 255).astype(np.uint8)[0]
    print("Final Output Image Shape:", output_img.shape)
    # cv2.imwrite returns False instead of raising when it cannot write.
    if not cv2.imwrite(args.output, output_img):
        raise SystemExit(f"Error: could not write output image: {args.output}")