PaddleDetection/deploy/python/clrnet_postprocess.py

263 lines
8.9 KiB
Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
from scipy.special import softmax
from scipy.interpolate import InterpolatedUnivariateSpline
def line_iou(pred, target, img_w, length=15, aligned=True):
    """Compute the line IoU between predicted and target lanes.

    Every lane is a vector of x coordinates sampled at fixed rows. Each x is
    widened into the segment ``[x - length, x + length]``. The per-row
    overlaps and unions of those segments are summed, and the ratio of the
    two sums is returned. Rows where the *target* x lies outside
    ``[0, img_w)`` contribute nothing to either sum.

    Args:
        pred: lane predictions, shape ``(num_pred, num_points)``.
        target: ground-truth lanes, shape ``(num_target, num_points)``.
        img_w: image width used to decide which target points are valid.
        length: half width of the segment built around each point.
        aligned: True compares ``pred`` and ``target`` element-wise (IoU
            loss). False computes pair-wise IoUs of shape
            ``(num_pred, num_target)`` (used for assignment).

    Returns:
        paddle.Tensor of IoU values (the overlap sum may be negative when
        segments do not intersect).
    """
    pred_left, pred_right = pred - length, pred + length
    tgt_left, tgt_right = target - length, target + length

    if aligned:
        reference = target
    else:
        # Broadcast to (num_pred, num_target, num_points) for pair-wise IoU.
        reference = target.tile([pred.shape[0], 1, 1])
        pred_left = pred_left[:, None, :]
        pred_right = pred_right[:, None, :]
        tgt_left = tgt_left[None, ...]
        tgt_right = tgt_right[None, ...]

    overlap = (paddle.minimum(pred_right, tgt_right) -
               paddle.maximum(pred_left, tgt_left))
    union = (paddle.maximum(pred_right, tgt_right) -
             paddle.minimum(pred_left, tgt_left))

    # Ignore rows where the target lane is outside the image horizontally.
    outside = (reference < 0) | (reference >= img_w)
    overlap[outside] = 0.
    union[outside] = 0.

    return overlap.sum(axis=-1) / (union.sum(axis=-1) + 1e-9)
class Lane:
    """A single lane curve, fitted with a spline ``x = f(y)``.

    Args:
        points (np.ndarray): array of shape ``(N, 2)`` holding ``(x, y)``
            rows. ``y`` must be increasing (a requirement of scipy's
            ``InterpolatedUnivariateSpline``), and at least two points are
            needed because the spline degree is ``min(3, N - 1)``. Despite
            the ``None`` default, a value is required.
        invalid_value (float): x value reported by ``__call__`` for rows
            outside the lane's vertical extent.
        metadata (dict | None): arbitrary extra information attached to
            the lane; defaults to an empty dict.
    """

    def __init__(self, points=None, invalid_value=-2., metadata=None):
        super(Lane, self).__init__()
        # Cursor used only by the legacy ``next(lane)`` protocol.
        self.curr_iter = 0
        self.points = points
        self.invalid_value = invalid_value
        # Fit x as a function of y; the degree is capped by the point count.
        self.function = InterpolatedUnivariateSpline(
            points[:, 1], points[:, 0], k=min(3, len(points) - 1))
        # Small margin so rows at the exact extremes still count as inside.
        self.min_y = points[:, 1].min() - 0.01
        self.max_y = points[:, 1].max() + 0.01
        self.metadata = metadata or {}

    def __repr__(self):
        return '[Lane]\n' + str(self.points) + '\n[/Lane]'

    def __call__(self, lane_ys):
        """Evaluate the lane x positions at the array ``lane_ys``.

        Rows outside ``[min_y, max_y]`` are set to ``invalid_value``.
        """
        lane_xs = self.function(lane_ys)
        outside = (lane_ys < self.min_y) | (lane_ys > self.max_y)
        lane_xs[outside] = self.invalid_value
        return lane_xs

    def to_array(self, sample_y_range, img_w, img_h):
        """Sample the lane at pixel rows and return pixel coordinates.

        Args:
            sample_y_range: ``(start, stop, step)`` passed to ``range`` to
                produce the pixel rows to sample.
            img_w: image width used to scale normalized x to pixels.
            img_h: image height used to normalize / de-normalize y.

        Returns:
            np.ndarray of shape ``(M, 2)`` with ``(x, y)`` pixel
            coordinates, keeping only samples whose normalized x lies in
            ``[0, 1)``.
        """
        self.sample_y = range(sample_y_range[0], sample_y_range[1],
                              sample_y_range[2])
        ys = np.array(self.sample_y) / float(img_h)
        xs = self(ys)
        valid_mask = (xs >= 0) & (xs < 1)
        lane_xs = xs[valid_mask] * img_w
        lane_ys = ys[valid_mask] * img_h
        return np.concatenate(
            (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), axis=1)

    def __iter__(self):
        # Return an independent iterator so that a partially consumed loop
        # (``break``) or nested iteration never resumes from a stale shared
        # cursor.
        return iter(self.points)

    def __next__(self):
        # Kept for backward compatibility with callers using ``next(lane)``.
        if self.curr_iter < len(self.points):
            self.curr_iter += 1
            return self.points[self.curr_iter - 1]
        self.curr_iter = 0
        raise StopIteration
class CLRNetPostProcess(object):
    """Decode raw CLRNet outputs into lanes (confidence filter + lane NMS).

    Each prediction row is read as: columns 0-1 are two-class logits (column
    1 is the lane score after softmax), column 2 is ``start_y``, column 3 is
    ``start_x``, column 4 is not used here, column 5 is the lane length
    (normalized, converted to strips in ``get_lanes``) and columns ``6:``
    hold ``num_points`` normalized x coordinates.

    Args:
        img_w (int): network input width; used to convert normalized x
            coordinates to pixels for NMS.
        ori_img_h (int): height of the original image.
        cut_height (int): rows above this offset are assumed to have been
            cropped before inference; used to map prior rows back to the
            original image.
        conf_threshold (float): minimum lane score to keep a prediction.
        nms_thres (float): NMS threshold compared against ``1 - line IoU``.
        max_lanes (int): maximum number of lanes kept per image.
        num_points (int): number of sampled rows (x coordinates) per lane.
    """

    def __init__(self, img_w, ori_img_h, cut_height, conf_threshold, nms_thres,
                 max_lanes, num_points):
        self.img_w = img_w
        self.conf_threshold = conf_threshold
        self.nms_thres = nms_thres
        self.max_lanes = max_lanes
        self.num_points = num_points
        self.n_strips = num_points - 1
        self.n_offsets = num_points
        self.ori_img_h = ori_img_h
        self.cut_height = cut_height
        # Normalized prior rows, ordered from the image bottom (1) to top (0).
        self.prior_ys = paddle.linspace(
            start=1, stop=0, num=self.n_offsets).astype('float64')

    def predictions_to_pred(self, predictions):
        """Convert kept predictions into ``Lane`` objects.

        Args:
            predictions (paddle.Tensor): one row per lane, laid out as
                described on the class, with column 5 already in strips.

        Returns:
            list[Lane]: lanes with at least two valid points. Points are
            ``(x, y)`` with x as predicted (normalized, presumably to the
            input width) and y normalized by ``ori_img_h``.
        """
        lanes = []
        num_priors = len(self.prior_ys)
        for pred in predictions:
            lane_xs = pred[6:].clone()
            start = min(
                max(0, int(round(pred[2].item() * self.n_strips))),
                self.n_strips)
            length = int(round(pred[5].item()))
            end = min(start + length - 1, num_priors - 1)
            if start > 0:
                # Rows before ``start`` are below the lane start. Keep only
                # the run of in-image points contiguous with ``start`` and
                # invalidate everything further down. The mask spans the
                # whole vector and is applied with paddle.where: assigning
                # through a slice (``lane_xs[:start][mask] = ...``) may
                # write into a temporary copy and be silently lost.
                head = lane_xs[:start].cpu().detach().numpy()
                inside = ((head >= 0.) & (head <= 1.))[::-1]
                drop = np.zeros([lane_xs.shape[0]], dtype=np.bool_)
                drop[:start] = ~(inside.cumprod()[::-1].astype(np.bool_))
                lane_xs = paddle.where(
                    paddle.to_tensor(drop),
                    paddle.full_like(lane_xs, -2.), lane_xs)
            if end < num_priors - 1:
                lane_xs[end + 1:] = -2
            valid = lane_xs >= 0
            lane_ys = self.prior_ys[valid].clone()
            lane_xs = lane_xs[valid]
            # The spline fit in Lane needs at least two points.
            if lane_xs.shape[0] <= 1:
                continue
            # Flip so y increases (bottom-to-top priors -> top-to-bottom),
            # as required by the spline fit.
            lane_xs = lane_xs.flip(axis=0).astype('float64')
            lane_ys = lane_ys.flip(axis=0)
            # Map prior rows back to the uncropped original image height.
            lane_ys = (lane_ys * (self.ori_img_h - self.cut_height) +
                       self.cut_height) / self.ori_img_h
            points = paddle.concat(
                [lane_xs.reshape([-1, 1]), lane_ys.reshape([-1, 1])], axis=1)
            lanes.append(
                Lane(
                    points=points.cpu().numpy(),
                    metadata={
                        'start_x': pred[3],
                        'start_y': pred[2],
                        'conf': pred[1]
                    }))
        return lanes

    def lane_nms(self, predictions, scores, nms_overlap_thresh, top_k):
        """Greedy NMS over lanes using the line-IoU distance.

        Args:
            predictions (paddle.Tensor): per-lane x coordinates in pixels,
                shape ``[num_lanes, num_points]`` (as passed by
                ``get_lanes``).
            scores (paddle.Tensor): lane scores, shape ``[num_lanes]``.
            nms_overlap_thresh (float): a remaining lane is suppressed when
                ``1 - line_iou`` with the current best lane is ``<=`` this.
            top_k (int): maximum number of lanes to keep.

        Returns:
            paddle.Tensor: indices into ``predictions`` of the kept lanes,
            in descending score order.
        """
        if predictions.shape[0] == 0:
            return paddle.zeros([0], dtype='int64')
        order = scores.argsort(descending=True)
        candidates = predictions.index_select(order)
        keep = []
        while len(candidates) > 0:
            keep.append(order[0])
            if len(keep) >= top_k or len(candidates) == 1:
                break
            # Distance (1 - IoU) from every remaining lane to the current
            # best one, in a single pair-wise call: shape [num_remaining, 1].
            dists = 1 - line_iou(
                candidates[1:],
                candidates[:1],
                img_w=self.img_w,
                length=15,
                aligned=False)
            suppressed = dists <= nms_overlap_thresh
            remain = paddle.where(paddle.logical_not(suppressed))[0]
            if remain.shape[0] == 0:
                break
            candidates = candidates[1:].index_select(remain)
            order = order[1:].index_select(remain)
        return paddle.stack(keep)

    def get_lanes(self, output, as_lanes=True):
        """Convert model output into lanes.

        Args:
            output: iterable of per-image prediction tensors, laid out as
                described on the class.
            as_lanes (bool): return ``Lane`` objects when True, otherwise
                the kept prediction tensor (column 5 converted to strips).

        Returns:
            list: one entry per image; an empty list when nothing is kept.
        """
        score_softmax = nn.Softmax(axis=1)
        decoded = []
        for predictions in output:
            if len(predictions) == 0:
                decoded.append([])
                continue
            scores = score_softmax(predictions[:, :2])[:, 1]
            keep_inds = scores >= self.conf_threshold
            predictions = predictions[keep_inds]
            scores = scores[keep_inds]
            if predictions.shape[0] == 0:
                decoded.append([])
                continue
            # NMS only looks at the x coordinates, converted to pixels.
            nms_xs = predictions[..., 6:].detach() * (self.img_w - 1)
            keep = self.lane_nms(
                nms_xs,
                scores,
                nms_overlap_thresh=self.nms_thres,
                top_k=self.max_lanes)
            predictions = predictions.index_select(keep)
            if predictions.shape[0] == 0:
                decoded.append([])
                continue
            predictions[:, 5] = paddle.round(predictions[:, 5] * self.n_strips)
            if as_lanes:
                decoded.append(self.predictions_to_pred(predictions))
            else:
                decoded.append(predictions)
        return decoded

    def __call__(self, lanes_list):
        """Alias for ``get_lanes(lanes_list)`` returning ``Lane`` objects."""
        return self.get_lanes(lanes_list)