Fix some issues in `direct_return` rule (#4783)

This commit is contained in:
Danny Mösch 2023-02-25 08:48:31 +01:00 committed by GitHub
parent 613e916c39
commit c9b1b961f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 3 deletions

View File

@ -299,6 +299,28 @@ extension Trivia {
var isSingleSpace: Bool { var isSingleSpace: Bool {
self == .spaces(1) self == .spaces(1)
} }
var withFirstEmptyLineRemoved: Trivia {
if let index = firstIndex(where: \.isNewline), index < endIndex {
return Trivia(pieces: dropFirst(index + 1))
}
return self
}
var withoutTrailingIndentation: Trivia {
Trivia(pieces: reversed().drop(while: \.isHorizontalWhitespace).reversed())
}
}
extension TriviaPiece {
var isHorizontalWhitespace: Bool {
switch self {
case .spaces, .tabs:
return true
default:
return false
}
}
} }
extension IntegerLiteralExprSyntax { extension IntegerLiteralExprSyntax {

View File

@ -132,16 +132,30 @@ struct DirectReturnRule: SwiftSyntaxCorrectableRule, ConfigurationProviderRule,
Example(""" Example("""
func f() -> Int { func f() -> Int {
{ _ in { _ in
// A comment
let b = 2 let b = 2
// Another comment
return b return b
}(1) }(1)
} }
"""): Example(""" """): Example("""
func f() -> Int { func f() -> Int {
{ _ in { _ in
// A comment
// Another comment
return 2 return 2
}(1) }(1)
} }
"""),
Example("""
func f() -> Bool {
let b : Bool = true
return b
}
"""): Example("""
func f() -> Bool {
return true as Bool
}
""") """)
] ]
) )
@ -197,10 +211,10 @@ private class Rewriter: SyntaxRewriter, ViolationsSyntaxRewriter {
override func visit(_ statements: CodeBlockItemListSyntax) -> CodeBlockItemListSyntax { override func visit(_ statements: CodeBlockItemListSyntax) -> CodeBlockItemListSyntax {
guard let (binding, returnStmt) = statements.violation, guard let (binding, returnStmt) = statements.violation,
!returnStmt.isContainedIn(regions: disabledRegions, locationConverter: locationConverter), !binding.isContainedIn(regions: disabledRegions, locationConverter: locationConverter),
let bindingList = binding.parent?.as(PatternBindingListSyntax.self), let bindingList = binding.parent?.as(PatternBindingListSyntax.self),
let varDecl = bindingList.parent?.as(VariableDeclSyntax.self), let varDecl = bindingList.parent?.as(VariableDeclSyntax.self),
let initExpression = binding.initializer?.value else { var initExpression = binding.initializer?.value else {
return super.visit(statements) return super.visit(statements)
} }
correctionPositions.append(binding.positionAfterSkippingLeadingTrivia) correctionPositions.append(binding.positionAfterSkippingLeadingTrivia)
@ -214,6 +228,15 @@ private class Rewriter: SyntaxRewriter, ViolationsSyntaxRewriter {
} }
return item return item
} }
if let type = binding.typeAnnotation?.type {
initExpression = ExprSyntax(
fromProtocol: AsExprSyntax(
expression: initExpression.trimmed,
asTok: .keyword(.as).with(\.leadingTrivia, .space).with(\.trailingTrivia, .space),
typeName: type.trimmed
)
)
}
if newBindingList.isNotEmpty { if newBindingList.isNotEmpty {
newStmtList.append(CodeBlockItemSyntax( newStmtList.append(CodeBlockItemSyntax(
item: .decl(DeclSyntax(varDecl.with(\.bindings, PatternBindingListSyntax(newBindingList)))) item: .decl(DeclSyntax(varDecl.with(\.bindings, PatternBindingListSyntax(newBindingList))))
@ -222,7 +245,9 @@ private class Rewriter: SyntaxRewriter, ViolationsSyntaxRewriter {
item: .stmt(StmtSyntax(returnStmt.with(\.expression, initExpression))) item: .stmt(StmtSyntax(returnStmt.with(\.expression, initExpression)))
)) ))
} else { } else {
let leadingTrivia = (binding.trailingTrivia ?? .zero) + (returnStmt.leadingTrivia ?? .zero) let leadingTrivia = (varDecl.leadingTrivia?.withoutTrailingIndentation ?? .zero)
+ (varDecl.trailingTrivia ?? .zero)
+ (returnStmt.leadingTrivia?.withFirstEmptyLineRemoved ?? .zero)
newStmtList.append( newStmtList.append(
CodeBlockItemSyntax( CodeBlockItemSyntax(
item: .stmt( item: .stmt(