diff --git a/libs/java-ast/src/lib/core/Writer.spec.ts b/libs/java-ast/src/lib/core/Writer.spec.ts index 3043cd7..af87e0c 100644 --- a/libs/java-ast/src/lib/core/Writer.spec.ts +++ b/libs/java-ast/src/lib/core/Writer.spec.ts @@ -152,4 +152,20 @@ describe("Writer", () => { const formatted = writer.formatCode("public void method() {"); expect(formatted).toBe("public void method() {"); }); + + it("should correctly de-duplicate imports and references", () => { + writer.addReference( + new ClassReference({ name: "ArrayList", packageName: "java.util" }), + ); + writer.addReference( + new ClassReference({ name: "ArrayList", packageName: "java.util" }), + ); + writer.addImport("java.util.ArrayList"); + writer.addImport("java.util.ArrayList"); + + const output = writer.toString(); + expect(output.trim()).toBe(`package com.example; + +import java.util.ArrayList;`); + }); }); diff --git a/libs/java-ast/src/lib/core/Writer.ts b/libs/java-ast/src/lib/core/Writer.ts index 1eb7716..0155564 100644 --- a/libs/java-ast/src/lib/core/Writer.ts +++ b/libs/java-ast/src/lib/core/Writer.ts @@ -26,7 +26,7 @@ export class Writer implements IWriter { private indentLevel = 0; private buffer = ""; private imports: Set = new Set(); - private references: Set = new Set(); + private references: Set = new Set(); private packageName: string; private skipPackageDeclaration: boolean; @@ -111,7 +111,9 @@ export class Writer implements IWriter { if (reference.packageName === this.packageName) { return; } - this.references.add(reference); + // Sets compare objects by reference, so we need to + // convert to string before adding + this.references.add(`${reference.packageName}.${reference.name}`); } /** @@ -155,11 +157,10 @@ export class Writer implements IWriter { result += `package ${this.packageName};\n\n`; } - // Write imports - const allImports = [...this.imports]; - this.references.forEach((ref) => { - allImports.push(`${ref.packageName}.${ref.name}`); - }); + // Combine imports with references, and remove duplicates + const allImports = [ + ...new Set([...this.imports, ...this.references]), + ]; // Sort imports allImports.sort().forEach((importName) => { diff --git a/libs/java-ast/tests/__snapshots__/ComplexExample.spec.ts.snap b/libs/java-ast/tests/__snapshots__/ComplexExample.spec.ts.snap index 8ded061..cc76c1e 100644 --- a/libs/java-ast/tests/__snapshots__/ComplexExample.spec.ts.snap +++ b/libs/java-ast/tests/__snapshots__/ComplexExample.spec.ts.snap @@ -3,37 +3,23 @@ exports[`Complex Example should generate a complex Spring Boot controller 1`] = ` "package com.example.controller; -import com.example.dto.UserDTO; -import com.example.dto.UserDTO; -import com.example.dto.UserDTO; -import com.example.dto.UserDTO; -import com.example.dto.UserDTO; -import com.example.dto.UserDTO; import com.example.dto.UserDTO; import com.example.exception.DuplicateResourceException; import com.example.service.UserService; -import com.example.service.UserService; import io.swagger.v3.oas.annotations.Operation; import java.lang.String; import java.net.URI; import java.util.HashMap; import java.util.List; -import java.util.List; -import java.util.Map; import java.util.Map; import javax.validation.Valid; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.Autowired; import org.springframework.http.HttpStatus; -import org.springframework.http.HttpStatus; -import org.springframework.http.HttpStatus; -import org.springframework.http.ResponseEntity; -import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity; import org.springframework.web.bind.MethodArgumentNotValidException.MethodArgumentNotValidException; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody;