commit
69d78ddba3
|
@ -63,6 +63,8 @@ public:
|
||||||
|
|
||||||
this->cls_thresh = stod(config_map_["cls_thresh"]);
|
this->cls_thresh = stod(config_map_["cls_thresh"]);
|
||||||
|
|
||||||
|
this->rec_batch_num = stoi(config_map_["rec_batch_num"]);
|
||||||
|
|
||||||
this->visualize = bool(stoi(config_map_["visualize"]));
|
this->visualize = bool(stoi(config_map_["visualize"]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,6 +88,8 @@ public:
|
||||||
|
|
||||||
double det_db_unclip_ratio = 2.0;
|
double det_db_unclip_ratio = 2.0;
|
||||||
|
|
||||||
|
int rec_batch_num = 30;
|
||||||
|
|
||||||
std::string det_model_dir;
|
std::string det_model_dir;
|
||||||
|
|
||||||
std::string rec_model_dir;
|
std::string rec_model_dir;
|
||||||
|
|
|
@ -40,13 +40,14 @@ public:
|
||||||
const int &gpu_id, const int &gpu_mem,
|
const int &gpu_id, const int &gpu_mem,
|
||||||
const int &cpu_math_library_num_threads,
|
const int &cpu_math_library_num_threads,
|
||||||
const bool &use_mkldnn, const bool &use_zero_copy_run,
|
const bool &use_mkldnn, const bool &use_zero_copy_run,
|
||||||
const string &label_path) {
|
const string &label_path, const int& rec_batch_num) {
|
||||||
this->use_gpu_ = use_gpu;
|
this->use_gpu_ = use_gpu;
|
||||||
this->gpu_id_ = gpu_id;
|
this->gpu_id_ = gpu_id;
|
||||||
this->gpu_mem_ = gpu_mem;
|
this->gpu_mem_ = gpu_mem;
|
||||||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||||
this->use_mkldnn_ = use_mkldnn;
|
this->use_mkldnn_ = use_mkldnn;
|
||||||
this->use_zero_copy_run_ = use_zero_copy_run;
|
this->use_zero_copy_run_ = use_zero_copy_run;
|
||||||
|
this->rec_batch_num_ = rec_batch_num;
|
||||||
|
|
||||||
this->label_list_ = Utility::ReadDict(label_path);
|
this->label_list_ = Utility::ReadDict(label_path);
|
||||||
this->label_list_.push_back(" ");
|
this->label_list_.push_back(" ");
|
||||||
|
@ -69,6 +70,7 @@ private:
|
||||||
int cpu_math_library_num_threads_ = 4;
|
int cpu_math_library_num_threads_ = 4;
|
||||||
bool use_mkldnn_ = false;
|
bool use_mkldnn_ = false;
|
||||||
bool use_zero_copy_run_ = false;
|
bool use_zero_copy_run_ = false;
|
||||||
|
int rec_batch_num_ = 30;
|
||||||
|
|
||||||
std::vector<std::string> label_list_;
|
std::vector<std::string> label_list_;
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,7 @@ int main(int argc, char **argv) {
|
||||||
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
|
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
|
||||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||||
config.use_mkldnn, config.use_zero_copy_run,
|
config.use_mkldnn, config.use_zero_copy_run,
|
||||||
config.char_list_file);
|
config.char_list_file, config.rec_batch_num);
|
||||||
|
|
||||||
#ifdef USE_MKL
|
#ifdef USE_MKL
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
|
@ -91,11 +91,11 @@ int main(int argc, char **argv) {
|
||||||
auto end = std::chrono::system_clock::now();
|
auto end = std::chrono::system_clock::now();
|
||||||
auto duration =
|
auto duration =
|
||||||
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
|
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
|
||||||
std::cout << "花费了"
|
std::cout << "cost"
|
||||||
<< double(duration.count()) *
|
<< double(duration.count()) *
|
||||||
std::chrono::microseconds::period::num /
|
std::chrono::microseconds::period::num /
|
||||||
std::chrono::microseconds::period::den
|
std::chrono::microseconds::period::den
|
||||||
<< "秒" << std::endl;
|
<< "s" << std::endl;
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,20 @@
|
||||||
|
|
||||||
#include <include/ocr_rec.h>
|
#include <include/ocr_rec.h>
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
vector<int> argsort(const std::vector<T>& array)
|
||||||
|
{
|
||||||
|
const int array_len(array.size());
|
||||||
|
std::vector<int> array_index(array_len, 0);
|
||||||
|
for (int i = 0; i < array_len; ++i)
|
||||||
|
array_index[i] = i;
|
||||||
|
|
||||||
|
std::sort(array_index.begin(), array_index.end(),
|
||||||
|
[&array](int pos1, int pos2) {return (array[pos1] < array[pos2]); });
|
||||||
|
|
||||||
|
return array_index;
|
||||||
|
}
|
||||||
|
|
||||||
namespace PaddleOCR {
|
namespace PaddleOCR {
|
||||||
|
|
||||||
void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
||||||
|
@ -22,100 +36,122 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
||||||
img.copyTo(srcimg);
|
img.copyTo(srcimg);
|
||||||
cv::Mat crop_img;
|
cv::Mat crop_img;
|
||||||
cv::Mat resize_img;
|
cv::Mat resize_img;
|
||||||
|
std::vector<float> width_list;
|
||||||
|
std::vector<cv::Mat> img_list;
|
||||||
|
|
||||||
|
for (int i = boxes.size() - 1; i >= 0; i--) {
|
||||||
|
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
||||||
|
if (cls != nullptr) {
|
||||||
|
crop_img = cls->Run(crop_img);
|
||||||
|
}
|
||||||
|
img_list.push_back(crop_img);
|
||||||
|
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
|
||||||
|
width_list.push_back(wh_ratio);
|
||||||
|
}
|
||||||
|
//sort box
|
||||||
|
vector<int> sort_index = argsort(width_list);
|
||||||
|
int batch_num1 = this->rec_batch_num_;//batchsize
|
||||||
std::cout << "The predicted text is :" << std::endl;
|
std::cout << "The predicted text is :" << std::endl;
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (int i = boxes.size() - 1; i >= 0; i--) {
|
int beg_img_no = 0;
|
||||||
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
int end_img_no = 0;
|
||||||
if (cls != nullptr) {
|
for (int beg_img_no = 0; beg_img_no < img_list.size(); beg_img_no += batch_num1)
|
||||||
crop_img = cls->Run(crop_img);
|
{
|
||||||
|
float max_wh_ratio = 0;
|
||||||
|
end_img_no = min((int)boxes.size(), beg_img_no + batch_num1);
|
||||||
|
int batch_num = min(end_img_no - beg_img_no, batch_num1);
|
||||||
|
max_wh_ratio = width_list[sort_index[end_img_no - 1]];
|
||||||
|
int imgW1 = int(32 * max_wh_ratio);
|
||||||
|
int nqu, nra;
|
||||||
|
nqu = imgW1 / 4;
|
||||||
|
nra = imgW1 % 4;
|
||||||
|
int imgW = imgW1;
|
||||||
|
if (nra > 0)
|
||||||
|
{
|
||||||
|
imgW = int(4 * (nqu + 1));
|
||||||
}
|
}
|
||||||
|
std::vector<float> input(batch_num * 3 * 32 * imgW, 0.0f);//batchsize input
|
||||||
|
for (int i = beg_img_no; i < end_img_no; i++)
|
||||||
|
{
|
||||||
|
crop_img = img_list[sort_index[i]];
|
||||||
|
this->resize_op_.Run(crop_img, resize_img, max_wh_ratio);//resize
|
||||||
|
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
|
||||||
|
this->is_scale_);
|
||||||
|
|
||||||
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
|
cv::Mat padding_im;
|
||||||
|
cv::copyMakeBorder(resize_img, padding_im, 0, 0, 0, int(imgW - resize_img.cols), cv::BORDER_CONSTANT, { 0, 0, 0 });//padding image
|
||||||
|
|
||||||
this->resize_op_.Run(crop_img, resize_img, wh_ratio);
|
this->permute_op_.Run(&padding_im, input.data() + (i - beg_img_no) * 3 * padding_im.rows * padding_im.cols);
|
||||||
|
|
||||||
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
|
|
||||||
this->is_scale_);
|
|
||||||
|
|
||||||
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
|
|
||||||
|
|
||||||
this->permute_op_.Run(&resize_img, input.data());
|
|
||||||
|
|
||||||
// Inference.
|
|
||||||
if (this->use_zero_copy_run_) {
|
|
||||||
auto input_names = this->predictor_->GetInputNames();
|
|
||||||
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
|
||||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
|
||||||
input_t->copy_from_cpu(input.data());
|
|
||||||
this->predictor_->ZeroCopyRun();
|
|
||||||
} else {
|
|
||||||
paddle::PaddleTensor input_t;
|
|
||||||
input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
|
|
||||||
input_t.data =
|
|
||||||
paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
|
|
||||||
input_t.dtype = PaddleDType::FLOAT32;
|
|
||||||
std::vector<paddle::PaddleTensor> outputs;
|
|
||||||
this->predictor_->Run({input_t}, &outputs, 1);
|
|
||||||
}
|
}
|
||||||
|
auto input_names = this->predictor_->GetInputNames();
|
||||||
|
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||||
|
input_t->Reshape({ batch_num, 3, 32, imgW });
|
||||||
|
input_t->copy_from_cpu(input.data());
|
||||||
|
|
||||||
|
this->predictor_->ZeroCopyRun();
|
||||||
|
|
||||||
std::vector<int64_t> rec_idx;
|
std::vector<int64_t> rec_idx;
|
||||||
auto output_names = this->predictor_->GetOutputNames();
|
auto output_names = this->predictor_->GetOutputNames();
|
||||||
auto output_t = this->predictor_->GetOutputTensor(output_names[0]);
|
auto output_t = this->predictor_->GetOutputTensor(output_names[0]);
|
||||||
auto rec_idx_lod = output_t->lod();
|
auto rec_idx_lod = output_t->lod()[0];
|
||||||
auto shape_out = output_t->shape();
|
|
||||||
|
|
||||||
int out_num = std::accumulate(shape_out.begin(), shape_out.end(), 1,
|
|
||||||
std::multiplies<int>());
|
|
||||||
|
|
||||||
|
std::vector<int> output_shape = output_t->shape();
|
||||||
|
int out_num = 1;
|
||||||
|
for (int i = 0; i < output_shape.size(); ++i) {
|
||||||
|
out_num *= output_shape[i];
|
||||||
|
}
|
||||||
rec_idx.resize(out_num);
|
rec_idx.resize(out_num);
|
||||||
output_t->copy_to_cpu(rec_idx.data());
|
output_t->copy_to_cpu(rec_idx.data());//output data
|
||||||
|
|
||||||
std::vector<int> pred_idx;
|
|
||||||
for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1]); n++) {
|
|
||||||
pred_idx.push_back(int(rec_idx[n]));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pred_idx.size() < 1e-3)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
index += 1;
|
|
||||||
std::cout << index << "\t";
|
|
||||||
for (int n = 0; n < pred_idx.size(); n++) {
|
|
||||||
std::cout << label_list_[pred_idx[n]];
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<float> predict_batch;
|
std::vector<float> predict_batch;
|
||||||
auto output_t_1 = this->predictor_->GetOutputTensor(output_names[1]);
|
auto output_t_1 = this->predictor_->GetOutputTensor(output_names[1]);
|
||||||
|
|
||||||
auto predict_lod = output_t_1->lod();
|
auto predict_lod = output_t_1->lod()[0];
|
||||||
auto predict_shape = output_t_1->shape();
|
auto predict_shape = output_t_1->shape();
|
||||||
int out_num_1 = std::accumulate(predict_shape.begin(), predict_shape.end(),
|
|
||||||
1, std::multiplies<int>());
|
int out_num_1 = 1;
|
||||||
|
for (int i = 0; i < predict_shape.size(); ++i) {
|
||||||
|
out_num_1 *= predict_shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
predict_batch.resize(out_num_1);
|
predict_batch.resize(out_num_1);
|
||||||
output_t_1->copy_to_cpu(predict_batch.data());
|
output_t_1->copy_to_cpu(predict_batch.data());
|
||||||
|
|
||||||
int argmax_idx;
|
int argmax_idx;
|
||||||
int blank = predict_shape[1];
|
int blank = predict_shape[1];
|
||||||
float score = 0.f;
|
|
||||||
int count = 0;
|
|
||||||
float max_value = 0.0f;
|
|
||||||
|
|
||||||
for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) {
|
for (int j = 0; j < rec_idx_lod.size() - 1; j++)
|
||||||
argmax_idx =
|
{
|
||||||
int(Utility::argmax(&predict_batch[n * predict_shape[1]],
|
std::vector<int> pred_idx;
|
||||||
&predict_batch[(n + 1) * predict_shape[1]]));
|
float score = 0.f;
|
||||||
max_value =
|
int count = 0;
|
||||||
float(*std::max_element(&predict_batch[n * predict_shape[1]],
|
float max_value = 0.0f;
|
||||||
&predict_batch[(n + 1) * predict_shape[1]]));
|
for (int n = int(rec_idx_lod[j]); n < int(rec_idx_lod[j + 1]); n++) {
|
||||||
if (blank - 1 - argmax_idx > 1e-5) {
|
pred_idx.push_back(int(rec_idx[n]));
|
||||||
score += max_value;
|
}
|
||||||
count += 1;
|
if (pred_idx.size() < 1e-3)
|
||||||
}
|
continue;
|
||||||
|
|
||||||
|
index += 1;
|
||||||
|
std::cout << index << "\t";
|
||||||
|
for (int n = 0; n < pred_idx.size(); n++) {
|
||||||
|
std::cout << label_list_[pred_idx[n]];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int n = predict_lod[j]; n < predict_lod[j + 1] - 1; n++) {
|
||||||
|
argmax_idx =
|
||||||
|
int(Utility::argmax(&predict_batch[n * predict_shape[1]],
|
||||||
|
&predict_batch[(n + 1) * predict_shape[1]]));
|
||||||
|
|
||||||
|
max_value = predict_batch[n * predict_shape[1] + argmax_idx];
|
||||||
|
if (blank - 1 - argmax_idx > 1e-5) {
|
||||||
|
score += max_value;
|
||||||
|
count += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
score /= count;
|
||||||
|
std::cout << "\tscore: " << score << std::endl;
|
||||||
}
|
}
|
||||||
score /= count;
|
|
||||||
std::cout << "\tscore: " << score << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ cls_thresh 0.9
|
||||||
# rec config
|
# rec config
|
||||||
rec_model_dir ./inference/rec_crnn
|
rec_model_dir ./inference/rec_crnn
|
||||||
char_list_file ../../ppocr/utils/ppocr_keys_v1.txt
|
char_list_file ../../ppocr/utils/ppocr_keys_v1.txt
|
||||||
|
rec_batch_num 30
|
||||||
|
|
||||||
# show the detection results
|
# show the detection results
|
||||||
visualize 1
|
visualize 1
|
||||||
|
|
Loading…
Reference in New Issue