Add probability outputs #1311

Open · wants to merge 1 commit into base: master
140 changes: 139 additions & 1 deletion easyocr/easyocr.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

-from .recognition import get_recognizer, get_text
+from .recognition import get_recognizer, get_text, get_text_prob
from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\
download_and_unzip, printProgressBar, diff, reformat_input,\
make_rotated_img_list, set_result_with_confidence,\
@@ -350,6 +350,93 @@ def detect(self, img, min_size = 20, text_threshold = 0.7, low_text = 0.4,\

return horizontal_list_agg, free_list_agg

def recognize_prob(self, img_cv_grey, horizontal_list=None, free_list=None,\
decoder = 'greedy', beamWidth= 5, batch_size = 1,\
workers = 0, allowlist = None, blocklist = None, detail = 1,\
rotation_info = None,paragraph = False,\
contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
y_ths = 0.5, x_ths = 1.0, reformat=True, output_format='standard'):
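# Variant of recognize(): each result row carries the raw per-timestep
# character probability matrix plus a confidence score, in place of the
# decoded text string.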

if reformat:
img, img_cv_grey = reformat_input(img_cv_grey)

if allowlist:
ignore_char = ''.join(set(self.character)-set(allowlist))
elif blocklist:
ignore_char = ''.join(set(blocklist))
else:
ignore_char = ''.join(set(self.character)-set(self.lang_char))

if self.model_lang in ['chinese_tra','chinese_sim']: decoder = 'greedy'

if (horizontal_list is None) and (free_list is None):
y_max, x_max = img_cv_grey.shape
horizontal_list = [[0, x_max, 0, y_max]]
free_list = []

# without gpu/parallelization, it is faster to process image one by one
if ((batch_size == 1) or (self.device == 'cpu')) and not rotation_info:
result = []
for bbox in horizontal_list:
h_list = [bbox]
f_list = []
image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH)
result0 = get_text_prob(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
workers, self.device)
result += result0
for bbox in free_list:
h_list = []
f_list = [bbox]
image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH)
result0 = get_text_prob(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
workers, self.device)
result += result0
# default mode will try to process multiple boxes at the same time
else:
image_list, max_width = get_image_list(horizontal_list, free_list, img_cv_grey, model_height = imgH)
image_len = len(image_list)
if rotation_info and image_list:
image_list = make_rotated_img_list(rotation_info, image_list)
max_width = max(max_width, imgH)

result = get_text_prob(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
workers, self.device)

if rotation_info and (horizontal_list+free_list):
# Reshape result to be a list of lists, each row being for
# one of the rotations (first row being no rotation)
result = set_result_with_confidence(
[result[image_len*i:image_len*(i+1)] for i in range(len(rotation_info) + 1)])

if self.model_lang == 'arabic':
direction_mode = 'rtl'
# item[1] is a probability matrix here rather than decoded text, so the
# bidi get_display() pass used in recognize() does not apply
else:
direction_mode = 'ltr'

if paragraph:
result = get_paragraph(result, x_ths=x_ths, y_ths=y_ths, mode = direction_mode)

if detail == 0:
return [item[1] for item in result]
elif output_format == 'dict':
if paragraph:
return [ {'boxes':item[0],'text':item[1]} for item in result]
return [ {'boxes':item[0],'text':item[1],'confident':item[2]} for item in result]
elif output_format == 'json':
if paragraph:
return [json.dumps({'boxes':[list(map(int, lst)) for lst in item[0]],'text':item[1]}, ensure_ascii=False) for item in result]
return [json.dumps({'boxes':[list(map(int, lst)) for lst in item[0]],'text':item[1].tolist(),'confident':float(item[2])}, ensure_ascii=False) for item in result]
elif output_format == 'free_merge':
return merge_to_free(result, free_list)
else:
return result

def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\
decoder = 'greedy', beamWidth= 5, batch_size = 1,\
workers = 0, allowlist = None, blocklist = None, detail = 1,\
@@ -472,6 +559,42 @@ def readtext(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
filter_ths, y_ths, x_ths, False, output_format)

return result

def readtext_prob(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
workers = 0, allowlist = None, blocklist = None, detail = 1,\
rotation_info = None, paragraph = False, min_size = 20,\
contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
canvas_size = 2560, mag_ratio = 1.,\
slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1,
threshold = 0.2, bbox_min_score = 0.2, bbox_min_size = 3, max_candidates = 0,
output_format='standard'):
'''
Parameters:
image: file path or numpy-array or a byte stream object
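Returns (standard format): a list of (box, probability_matrix,
confidence_score) tuples; decode a matrix to text with convert_prob_to_word.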
'''
img, img_cv_grey = reformat_input(image)

horizontal_list, free_list = self.detect(img,
min_size = min_size, text_threshold = text_threshold,\
low_text = low_text, link_threshold = link_threshold,\
canvas_size = canvas_size, mag_ratio = mag_ratio,\
slope_ths = slope_ths, ycenter_ths = ycenter_ths,\
height_ths = height_ths, width_ths= width_ths,\
add_margin = add_margin, reformat = False,\
threshold = threshold, bbox_min_score = bbox_min_score,\
bbox_min_size = bbox_min_size, max_candidates = max_candidates
)
# get the 1st result from hor & free list as self.detect returns a list of depth 3
horizontal_list, free_list = horizontal_list[0], free_list[0]
result = self.recognize_prob(img_cv_grey, horizontal_list, free_list,\
decoder, beamWidth, batch_size,\
workers, allowlist, blocklist, detail, rotation_info,\
paragraph, contrast_ths, adjust_contrast,\
filter_ths, y_ths, x_ths, False, output_format)

return result

def readtextlang(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
workers = 0, allowlist = None, blocklist = None, detail = 1,\
@@ -577,3 +700,18 @@ def readtext_batched(self, image, n_width=None, n_height=None,\
filter_ths, y_ths, x_ths, False, output_format))

return result_agg


def convert_prob_to_word(prob, converter):
"""
For use with the readtext_prob outputs.

- prob should be a 2-D array of shape (sequence_length, num_classes)
- converter = reader.converter
"""
assert prob.ndim == 2
preds_index = np.argmax(prob, axis=1)
preds_index = preds_index.flatten()
preds_size = np.array([prob.shape[0]])
preds_str = converter.decode_greedy(preds_index, preds_size)[0]
return preds_str
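
A minimal usage sketch for the new API (the image path and language choice below are illustrative, not part of this diff):

import easyocr
from easyocr.easyocr import convert_prob_to_word

reader = easyocr.Reader(['en'])
results = reader.readtext_prob('example.png')  # hypothetical input image
for box, prob, confidence in results:
    # prob is a (sequence_length x num_classes) probability matrix
    text = convert_prob_to_word(prob, reader.converter)
    print(box, text, confidence)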
91 changes: 91 additions & 0 deletions easyocr/recognition.py
@@ -150,6 +150,48 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\

return result

def recognizer_predict_prob(model, converter, test_loader, batch_max_length,\
ignore_idx, char_group_idx, decoder = 'greedy', beamWidth= 5, device = 'cpu'):
model.eval()
result = []
with torch.no_grad():
for image_tensors in test_loader:
batch_size = image_tensors.size(0)
image = image_tensors.to(device)
# For max length prediction
length_for_pred = torch.IntTensor([batch_max_length] * batch_size).to(device)
text_for_pred = torch.LongTensor(batch_size, batch_max_length + 1).fill_(0).to(device)

preds = model(image, text_for_pred)

# Select max probability (greedy decoding), then decode index to character
preds_size = torch.IntTensor([preds.size(1)] * batch_size)

######## filter ignore_char, rebalance
preds_prob = F.softmax(preds, dim=2)
preds_prob = preds_prob.cpu().detach().numpy()
preds_prob[:,:,ignore_idx] = 0.
pred_norm = preds_prob.sum(axis=2)
preds_prob = preds_prob/np.expand_dims(pred_norm, axis=-1)
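# after masking, each timestep's distribution is renormalized to sum to 1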
preds_prob = preds_prob.astype(np.float32)

values = preds_prob.max(axis=2)
indices = preds_prob.argmax(axis=2)
preds_max_prob = []
for v,i in zip(values, indices):
max_probs = v[i!=0] # this removes blanks
if len(max_probs)>0:
preds_max_prob.append(max_probs)
else:
preds_max_prob.append(np.array([0]))

for k, pred_max_prob in enumerate(preds_max_prob):
confidence_score = custom_mean(pred_max_prob)
# store the per-image probability matrix, not the whole batch
result.append([preds_prob[k], confidence_score])

return result

def get_recognizer(recog_network, network_params, character,\
separator_list, dict_list, model_path,\
device = 'cpu', quantize = True):
@@ -231,3 +273,52 @@ def get_text(character, imgH, imgW, recognizer, converter, image_list,\
result.append( (box, pred1[0], pred1[1]) )

return result

def get_text_prob(character, imgH, imgW, recognizer, converter, image_list,\
ignore_char = '',decoder = 'greedy', beamWidth =5, batch_size=1, contrast_ths=0.1,\
adjust_contrast=0.5, filter_ths = 0.003, workers = 1, device = 'cpu'):
batch_max_length = int(imgW/10)

char_group_idx = {}
ignore_idx = []
for char in ignore_char:
try: ignore_idx.append(character.index(char)+1)
except ValueError: pass

coord = [item[0] for item in image_list]
img_list = [item[1] for item in image_list]
AlignCollate_normal = AlignCollate(imgH=imgH, imgW=imgW, keep_ratio_with_pad=True)
test_data = ListDataset(img_list)
test_loader = torch.utils.data.DataLoader(
test_data, batch_size=batch_size, shuffle=False,
num_workers=int(workers), collate_fn=AlignCollate_normal, pin_memory=True)

# predict first round
result1 = recognizer_predict_prob(recognizer, converter, test_loader,batch_max_length,\
ignore_idx, char_group_idx, decoder, beamWidth, device = device)

# predict second round
low_confident_idx = [i for i,item in enumerate(result1) if (item[1] < contrast_ths)]
if len(low_confident_idx) > 0:
img_list2 = [img_list[i] for i in low_confident_idx]
AlignCollate_contrast = AlignCollate(imgH=imgH, imgW=imgW, keep_ratio_with_pad=True, adjust_contrast=adjust_contrast)
test_data = ListDataset(img_list2)
test_loader = torch.utils.data.DataLoader(
test_data, batch_size=batch_size, shuffle=False,
num_workers=int(workers), collate_fn=AlignCollate_contrast, pin_memory=True)
result2 = recognizer_predict_prob(recognizer, converter, test_loader, batch_max_length,\
ignore_idx, char_group_idx, decoder, beamWidth, device = device)

result = []
for i, zipped in enumerate(zip(coord, result1)):
box, pred1 = zipped
if i in low_confident_idx:
pred2 = result2[low_confident_idx.index(i)]
if pred1[1]>pred2[1]:
result.append( (box, pred1[0], pred1[1]) )
else:
result.append( (box, pred2[0], pred2[1]) )
else:
result.append( (box, pred1[0], pred1[1]) )

return result
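
A sketch of how the confidence score is derived from one probability matrix (toy numbers; index 0 is the CTC blank class, matching the blank-removal step in recognizer_predict_prob above):

import numpy as np
from easyocr.recognition import custom_mean

# toy (sequence_length x num_classes) matrix, as returned per box
prob = np.array([[0.90, 0.05, 0.05],   # timestep 0 -> blank (index 0)
                 [0.10, 0.80, 0.10],   # timestep 1 -> class 1
                 [0.20, 0.10, 0.70]])  # timestep 2 -> class 2
values = prob.max(axis=1)              # [0.90, 0.80, 0.70]
indices = prob.argmax(axis=1)          # [0, 1, 2]
max_probs = values[indices != 0]       # drop blank timesteps -> [0.80, 0.70]
confidence = custom_mean(max_probs)    # same aggregation used above
print(confidence)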