263 lines
8.9 KiB
Python
263 lines
8.9 KiB
Python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import numpy as np
|
|
import paddle
|
|
import paddle.nn as nn
|
|
from scipy.special import softmax
|
|
from scipy.interpolate import InterpolatedUnivariateSpline
|
|
|
|
|
|
def line_iou(pred, target, img_w, length=15, aligned=True):
    """
    Line IoU between predicted lanes and target lanes.

    Each x coordinate is widened into the horizontal segment
    ``[x - length, x + length]``. Segment overlaps and segment unions are
    summed over the last axis and divided to give one IoU per lane pair.

    Args:
        pred: lane predictions, shape: (num_pred, 72)
        target: ground truth, shape: (num_target, 72)
        img_w: image width; target coordinates outside ``[0, img_w)`` are
            excluded from both sums
        length: extended radius
        aligned: True compares ``pred`` and ``target`` element-wise (iou
            loss calculation); False compares every prediction with every
            target (pair-wise ious in assign)

    Returns:
        IoU tensor: one value per row when ``aligned`` is True, otherwise a
        (num_pred, num_target) matrix.
    """
    if aligned:
        pred_lo, pred_hi = pred - length, pred + length
        gt_lo, gt_hi = target - length, target + length
        gt_x = target
    else:
        # Insert broadcast axes so each prediction meets each target.
        pred_lo = (pred - length)[:, None, :]
        pred_hi = (pred + length)[:, None, :]
        gt_lo = (target - length)[None, ...]
        gt_hi = (target + length)[None, ...]
        gt_x = target.tile([pred.shape[0], 1, 1])

    ovr = paddle.minimum(pred_hi, gt_hi) - paddle.maximum(pred_lo, gt_lo)
    union = paddle.maximum(pred_hi, gt_hi) - paddle.minimum(pred_lo, gt_lo)

    # Positions whose target x lies outside the image contribute nothing.
    out_of_image = (gt_x < 0) | (gt_x >= img_w)
    ovr[out_of_image] = 0.
    union[out_of_image] = 0.

    # The epsilon keeps the ratio finite when every position was masked out.
    return ovr.sum(axis=-1) / (union.sum(axis=-1) + 1e-9)
|
|
|
|
|
|
class Lane:
    """
    A single lane stored as (x, y) points plus a spline giving x as a
    function of y.

    Args:
        points: array-like of shape (N, 2); column 0 is x, column 1 is y.
            Required despite the ``None`` default, which is kept only for
            signature compatibility. The spline is fitted with y as the
            independent variable, so scipy expects the y column to be
            increasing. The spline degree is ``min(3, N - 1)``.
        invalid_value: value returned for ys outside the fitted y range.
        metadata: optional dict of extra information; stored as ``{}`` when
            omitted.
    """

    def __init__(self, points=None, invalid_value=-2., metadata=None):
        super(Lane, self).__init__()
        if points is None:
            raise TypeError(
                'Lane requires `points`: an (N, 2) array of (x, y) rows')
        # Accept plain sequences as well; an ndarray passes through as-is.
        points = np.asarray(points)
        self.curr_iter = 0
        self.points = points
        self.invalid_value = invalid_value
        self.function = InterpolatedUnivariateSpline(
            points[:, 1], points[:, 0], k=min(3, len(points) - 1))
        # Small margin so the end points themselves are treated as in range.
        self.min_y = points[:, 1].min() - 0.01
        self.max_y = points[:, 1].max() + 0.01
        self.metadata = metadata or {}

    def __repr__(self):
        return '[Lane]\n' + str(self.points) + '\n[/Lane]'

    def __call__(self, lane_ys):
        """Evaluate x at ``lane_ys``; ys outside the fitted range map to
        ``invalid_value``."""
        lane_ys = np.asarray(lane_ys)
        lane_xs = self.function(lane_ys)
        outside = (lane_ys < self.min_y) | (lane_ys > self.max_y)
        lane_xs[outside] = self.invalid_value
        return lane_xs

    def to_array(self, sample_y_range, img_w, img_h):
        """
        Sample the lane on a regular grid of rows and return pixel points.

        Args:
            sample_y_range: ``(start, stop, step)`` passed to ``range`` to
                build the sampled rows; the range is also stored on
                ``self.sample_y``.
            img_w: width used to scale xs back to pixels.
            img_h: height used to normalise the sampled rows and to scale
                ys back to pixels.

        Returns:
            ndarray of shape (M, 2) with (x, y) rows, keeping only samples
            whose interpolated x lies in ``[0, 1)``.
        """
        self.sample_y = range(sample_y_range[0], sample_y_range[1],
                              sample_y_range[2])
        ys = np.array(self.sample_y) / float(img_h)
        xs = self(ys)
        valid_mask = (xs >= 0) & (xs < 1)
        lane_xs = xs[valid_mask] * img_w
        lane_ys = ys[valid_mask] * img_h
        return np.concatenate(
            (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), axis=1)

    def __iter__(self):
        # Hand out a fresh iterator each time so an abandoned or nested loop
        # cannot leave a shared cursor half-way through the points.
        return iter(self.points)

    def __next__(self):
        # Kept for callers that invoke ``next(lane)`` directly; this cursor
        # is independent of ``__iter__``.
        if self.curr_iter < len(self.points):
            self.curr_iter += 1
            return self.points[self.curr_iter - 1]
        self.curr_iter = 0
        raise StopIteration
|
|
|
|
|
|
class CLRNetPostProcess(object):
    """
    Turn raw lane predictions into final lanes: score filtering, lane NMS
    and optional conversion to ``Lane`` objects.

    Each prediction row is read as follows by the methods below:
    columns 0-1 are class logits, column 2 is the normalised start y,
    column 3 the start x, column 5 the lane length and columns 6 onwards
    the normalised x offsets. Column 4 is not used by this class.

    Args:
        img_w (int): image width; x offsets are scaled by ``img_w - 1``
            before NMS and ``img_w`` bounds the valid x range in the IoU
        ori_img_h (int): image height used to rescale lane ys
        cut_height (int): vertical offset applied when rescaling lane ys
            (presumably the rows cropped during preprocessing - confirm
            against the data pipeline)
        conf_threshold (float): minimum positive-class score to keep a lane
        nms_thres (float): NMS threshold, compared against ``1 - line_iou``
        max_lanes (int): maximum number of lanes kept by NMS
        num_points (int): number of x offsets per lane
    """

    def __init__(self, img_w, ori_img_h, cut_height, conf_threshold, nms_thres,
                 max_lanes, num_points):
        self.img_w = img_w
        self.conf_threshold = conf_threshold
        self.nms_thres = nms_thres
        self.max_lanes = max_lanes
        self.num_points = num_points
        self.n_strips = num_points - 1
        self.n_offsets = num_points
        self.ori_img_h = ori_img_h
        self.cut_height = cut_height

        # Normalised y of every offset, running from 1 down to 0.
        self.prior_ys = paddle.linspace(
            start=1, stop=0, num=self.n_offsets).astype('float64')

    def predictions_to_pred(self, predictions):
        """
        Convert predictions to internal Lane structure for evaluation.

        Args:
            predictions: iterable of prediction rows whose length column
                (index 5) is already expressed in strips.

        Returns:
            list of ``Lane``; rows with fewer than two valid points are
            skipped.
        """
        lanes = []
        last_idx = len(self.prior_ys) - 1
        for pred in predictions:
            lane_xs = pred[6:].clone()
            start = min(
                max(0, int(round(pred[2].item() * self.n_strips))),
                self.n_strips)
            length = int(round(pred[5].item()))
            end = min(start + length - 1, last_idx)

            if start > 0:
                # Before the start index, keep only the run of in-range xs
                # that is contiguous with the start; everything earlier is
                # invalidated. The slice is copied and written back so the
                # result does not depend on whether slicing aliases memory.
                head = lane_xs[:start].clone()
                in_range = ((head >= 0.) &
                            (head <= 1.)).cpu().detach().numpy()[::-1]
                drop = ~((in_range.cumprod()[::-1]).astype(np.bool_))
                head[drop] = -2
                lane_xs[:start] = head
            if end < last_idx:
                lane_xs[end + 1:] = -2

            visible = lane_xs >= 0
            lane_ys = self.prior_ys[visible].clone()
            lane_xs = lane_xs[visible]
            lane_xs = lane_xs.flip(axis=0).astype('float64')
            lane_ys = lane_ys.flip(axis=0)

            # Re-express ys as a fraction of the full image height.
            lane_ys = (lane_ys *
                       (self.ori_img_h - self.cut_height) + self.cut_height
                       ) / self.ori_img_h
            if len(lane_xs) <= 1:
                continue
            points = paddle.stack(
                x=(lane_xs.reshape([-1, 1]), lane_ys.reshape([-1, 1])),
                axis=1).squeeze(axis=2)
            lanes.append(
                Lane(
                    points=points.cpu().numpy(),
                    metadata={
                        'start_x': pred[3],
                        'start_y': pred[2],
                        # Raw column 1 of the row, not the softmax score.
                        'conf': pred[1]
                    }))
        return lanes

    def lane_nms(self, predictions, scores, nms_overlap_thresh, top_k):
        """
        NMS for lane detection.

        Candidates are visited in descending score order. After a lane is
        kept, every remaining candidate whose ``1 - line_iou`` with it is
        ``<= nms_overlap_thresh`` is discarded.

        Args:
            predictions: paddle.Tensor [num_lanes, n_offsets], x offsets as
                passed by ``get_lanes``
            scores: paddle.Tensor [num_lanes]
            nms_overlap_thresh: float
            top_k: int, maximum number of lanes to keep

        Returns:
            paddle.Tensor of kept indices into ``predictions``.
        """
        # sort by scores to get idx
        order = scores.argsort(descending=True)
        keep = []
        candidates = predictions.index_select(order)

        while len(candidates) > 0:
            keep.append(order[0])
            if len(keep) >= top_k or len(candidates) == 1:
                break

            best = candidates[0].unsqueeze(0)
            distances = paddle.to_tensor([
                1 - line_iou(
                    candidates[i].unsqueeze(0),
                    best,
                    img_w=self.img_w,
                    length=15) for i in range(1, len(candidates))
            ])

            suppressed = distances <= nms_overlap_thresh
            survivors = paddle.where(paddle.logical_not(suppressed))[0]

            if survivors.shape[0] == 0:
                break
            candidates = candidates[1:].index_select(survivors)
            order = order[1:].index_select(survivors)
        keep = paddle.stack(keep)

        return keep

    def get_lanes(self, output, as_lanes=True):
        """
        Convert model output to lanes.

        Args:
            output: iterable with one prediction tensor per image.
            as_lanes: when True return ``Lane`` objects, otherwise the
                filtered prediction tensor.

        Returns:
            list with one entry per image; an empty list for images with no
            surviving prediction.
        """
        softmax_layer = nn.Softmax(axis=1)
        decoded = []

        for predictions in output:
            if len(predictions) == 0:
                decoded.append([])
                continue
            scores = softmax_layer(predictions[:, :2])[:, 1]
            keep_inds = scores >= self.conf_threshold
            predictions = predictions[keep_inds]
            scores = scores[keep_inds]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue

            # NMS only looks at the x offsets, scaled to pixel units.
            nms_xs = predictions[..., 6:].detach() * (self.img_w - 1)
            keep = self.lane_nms(
                nms_xs,
                scores,
                nms_overlap_thresh=self.nms_thres,
                top_k=self.max_lanes)

            predictions = predictions.index_select(keep)

            if predictions.shape[0] == 0:
                decoded.append([])
                continue
            # Express the lane length in strips for predictions_to_pred.
            predictions[:, 5] = paddle.round(predictions[:, 5] * self.n_strips)
            if as_lanes:
                pred = self.predictions_to_pred(predictions)
            else:
                pred = predictions
            decoded.append(pred)
        return decoded

    def __call__(self, lanes_list):
        """Run ``get_lanes`` with its default settings."""
        lanes = self.get_lanes(lanes_list)
        return lanes
|