How to create JUnit assertions from running applications

JUnit is a cool and powerful framework for writing unit tests. But there is a drawback: we need to write a lot of assertions! Without assertions we can still achieve very good path coverage, but nothing verifies that the code actually works correctly. Writing them by hand gets really cumbersome when we are working with big JavaBeans or entities. Take for instance the following piece of code:

import org.junit.Test;

public class BeanTest {

    @Test
    public void testCreate() {
        ClassToTest theCandidate = new ClassToTest();
        TestBean theTestBean = theCandidate.create();

        // The assertions are still missing here, so for now we only print the bean
        System.out.println(theTestBean);
    }
}

Wouldn’t it be cool to run the test once and generate the assertions from the live object in the running application, so that we end up with the following?

import org.junit.Assert;
import org.junit.Test;

public class BeanTest {

    @Test
    public void testCreate() {
        ClassToTest theCandidate = new ClassToTest();
        TestBean theTestBean = theCandidate.create();

        // Print the bean; the generated assertions follow
        System.out.println(theTestBean);

        Assert.assertEquals("Name1", theTestBean.getName1());
        Assert.assertNull(theTestBean.getName2());
        Assert.assertEquals(Integer.valueOf(1), theTestBean.getValue1());
        Assert.assertEquals(Boolean.valueOf(true), theTestBean.getValue2());
        Assert.assertEquals(Double.valueOf(3.0d), theTestBean.getValue3());
    }
}
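
The listing above uses ClassToTest and TestBean without showing them. As a rough sketch, and purely inferred from the assertions, they might look like this (the fields, the constructor and the hard-coded values are assumptions, not the original classes):

```java
// Hypothetical bean; property names and types are inferred from the assertions above.
// Package-private only so that both classes fit into one file; declare them public
// in their own files if a reflection helper from another package has to call the getters.
class TestBean {
    private final String name1;
    private final String name2;
    private final Integer value1;
    private final Boolean value2;
    private final Double value3;

    TestBean(String name1, String name2, Integer value1, Boolean value2, Double value3) {
        this.name1 = name1;
        this.name2 = name2;
        this.value1 = value1;
        this.value2 = value2;
        this.value3 = value3;
    }

    public String getName1() { return name1; }
    public String getName2() { return name2; }
    public Integer getValue1() { return value1; }
    public Boolean getValue2() { return value2; }
    public Double getValue3() { return value3; }
}

// Hypothetical class under test; a real create() would probably do far more work.
class ClassToTest {
    public TestBean create() {
        return new TestBean("Name1", null, Integer.valueOf(1), Boolean.TRUE, Double.valueOf(3.0d));
    }
}
```

With only five properties the assertions are still quick to type by hand; the pain starts when a bean has dozens of getters.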

So how do we create the missing assertions? Just step into the unit test with a debugger, stop after the bean has been created, and evaluate the following expression in the debugger's expression evaluation window:

AssertGen.generate(theTestBean)

And you get the missing assertion statements as a string. They refer to the object as ob, so replace that placeholder with the real variable name (here theTestBean) when copying them into the unit test source, then rerun the test. Pretty cool, and it can save a lot of time while writing unit tests. But how does it work under the hood? The AssertGen class is quite simple, just check the source:

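import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.sql.Timestamp;
import java.util.List;

public class AssertGen {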

    public static String generate(Object aObject) throws InvocationTargetException, IllegalAccessException {
        StringWriter theStringWriter = new StringWriter();
        PrintWriter thePW = new PrintWriter(theStringWriter);
        if (aObject instanceof List) {
            List<?> theList = (List<?>) aObject;
            thePW.println("Assert.assertEquals(" + theList.size() + ", ob.size());");
            int row = 0;
            for (Object theObject : theList) {
                generate(thePW, "ob.get(" + row + ").", theObject);
                row++;
            }
        } else {
            generate(thePW, "ob.", aObject);
        }
        thePW.flush();
        return theStringWriter.toString();
    }

    private static void generate(PrintWriter aPW, String aPrefix, Object aObject) throws InvocationTargetException,
            IllegalAccessException {
        for (Method theMethod : aObject.getClass().getMethods()) {
            if (Modifier.isPublic(theMethod.getModifiers()) && theMethod.getParameterTypes().length == 0) {
                String theMethodName = theMethod.getName();
                // Accept getXxx() (except getClass()) and isXxx() methods
                if ((theMethodName.startsWith("get") && !theMethodName.equals("getClass")) || theMethodName.startsWith("is")) {
                    Object theValue = theMethod.invoke(aObject);
                    if (theValue == null) {
                        aPW.println("Assert.assertNull(" + aPrefix + theMethod.getName() + "());");
                    } else {
                        aPW.println("Assert.assertEquals(" + toAssertionText(theValue) + ", " + aPrefix
                                + theMethod.getName() + "());");
                    }
                }
            }
        }
    }

    private static String toAssertionText(Object aValue) {
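        if (aValue instanceof String) {
            // Escape backslashes and quotes so the generated literal compiles
            String theEscaped = ((String) aValue).replace("\\", "\\\\").replace("\"", "\\\"");
            return "\"" + theEscaped + "\"";
        }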
        if (aValue instanceof Integer) {
            return "Integer.valueOf(" + aValue + ")";
        }
        if (aValue instanceof Long) {
            return "Long.valueOf(" + aValue + "L)";
        }
        if (aValue instanceof Boolean) {
            return "Boolean.valueOf(" + aValue + ")";
        }
        if (aValue instanceof Double) {
            return "Double.valueOf(" + aValue + "d)";
        }
        if (aValue instanceof Timestamp) {
            Timestamp theTS = (Timestamp) aValue;
            return "new Timestamp(" + theTS.getTime() + "L)";
        }

        // Unsupported type: the result is not valid Java and has to be fixed by hand
        return "don't know how to convert: " + aValue;
    }
}
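
To see the idea without a debugger, here is a small self-contained sketch of the same reflection technique. It is not the AssertGen class itself: AssertGenDemo and SampleBean are made-up names, the variable name is passed in instead of the fixed ob prefix, and the getters are sorted by name because Class.getMethods() returns them in no particular order.

```java
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Self-contained sketch; AssertGenDemo and SampleBean are hypothetical names.
class AssertGenDemo {

    // Tiny stand-in for a real JavaBean.
    public static class SampleBean {
        public String getName() { return "Name1"; }
        public String getNickname() { return null; }
        public Integer getCount() { return Integer.valueOf(1); }
        public Boolean isActive() { return Boolean.TRUE; }
    }

    // Builds one assertion line per getter of aObject, sorted by method name.
    static List<String> assertionsFor(String aVariable, Object aObject) throws ReflectiveOperationException {
        List<Method> theGetters = new ArrayList<>();
        for (Method theMethod : aObject.getClass().getMethods()) {
            String theName = theMethod.getName();
            boolean isGetter = (theName.startsWith("get") && !theName.equals("getClass"))
                    || theName.startsWith("is");
            if (isGetter && theMethod.getParameterCount() == 0) {
                theGetters.add(theMethod);
            }
        }
        theGetters.sort(Comparator.comparing(Method::getName));

        List<String> theLines = new ArrayList<>();
        for (Method theGetter : theGetters) {
            // Call the getter on the live object and turn its value into source text
            Object theValue = theGetter.invoke(aObject);
            String theCall = aVariable + "." + theGetter.getName() + "()";
            if (theValue == null) {
                theLines.add("Assert.assertNull(" + theCall + ");");
            } else {
                theLines.add("Assert.assertEquals(" + toLiteral(theValue) + ", " + theCall + ");");
            }
        }
        return theLines;
    }

    // Only the value types used by SampleBean are handled in this sketch.
    static String toLiteral(Object aValue) {
        if (aValue instanceof String) {
            return "\"" + aValue + "\"";
        }
        if (aValue instanceof Integer) {
            return "Integer.valueOf(" + aValue + ")";
        }
        if (aValue instanceof Boolean) {
            return "Boolean.valueOf(" + aValue + ")";
        }
        throw new IllegalArgumentException("Unsupported type: " + aValue.getClass());
    }

    public static void main(String[] args) throws ReflectiveOperationException {
        for (String theLine : assertionsFor("theTestBean", new SampleBean())) {
            System.out.println(theLine);
        }
    }
}
```

Running main prints four statements, one per getter of SampleBean, ready to be pasted into a test. Throwing on an unsupported type instead of returning a placeholder text is a design choice of this sketch; it makes gaps visible immediately, at the price of aborting the whole generation.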

Nice, I really love JUnit!
