/***************************************
 *                                     *
 *  JBoss: The OpenSource J2EE WebOS   *
 *                                     *
 *  Distributable under LGPL license.  *
 *  See terms of license at gnu.org.   *
 *                                     *
 ***************************************/

package org.jboss.spring.support;

import org.jboss.annotation.spring.Spring;
import org.jboss.logging.Logger;
import org.jboss.naming.Util;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.util.Assert;

import java.lang.reflect.Field;
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.util.*;

/**
 * Injects beans from a Spring {@link BeanFactory} into a target object.
 * <p>
 * The bean factory is looked up under the name given by the {@code jndiName}
 * attribute of the {@link Spring} annotation, and the bean to inject is
 * selected by the annotation's {@code bean} attribute.
 * Injection is applied to setter methods and to fields annotated with
 * {@link Spring}, declared on the target's class or on any of its superclasses
 * (excluding {@code Object}).
 * <p>
 * When a method is declared with the same signature on several levels of the
 * class hierarchy, only the most derived declaration is considered.
 * The type of the bean is checked against the setter parameter type / field
 * type before the value is actually set.
 *
 * @author <a href="mailto:ales.justin@genera-lynx.com">Ales Justin</a>
 */
public abstract class SpringInjectionSupport {

    protected Logger log = Logger.getLogger(getClass());

    /** Stateless ordering that treats same-signature methods as duplicates. */
    private static final Comparator<Method> METHOD_COMPARATOR = new MethodComparator();

    /**
     * Performs Spring injection on all {@link Spring} annotated setter methods
     * and fields of the given target. Methods are processed before fields.
     *
     * @param target the object to inject into
     * @return the same {@code target} instance
     * @throws Throwable if lookup, type check or reflective access fails
     */
    protected Object inject(Object target) throws Throwable {

        log.debug("Invoking Spring injection: " + target.getClass().getName());

        for (Method method : getAllMethods(target)) {
            Spring spring = method.getAnnotation(Spring.class);
            if (spring == null) {
                continue;
            }
            if (isSetterMethod(method)) {
                injectToMethod(target, method, spring);
            } else {
                log.warn("Spring annotation only allowed on setter methods.");
            }
        }

        for (Field field : getAllFields(target)) {
            Spring spring = field.getAnnotation(Spring.class);
            if (spring != null) {
                injectToField(target, field, spring);
            }
        }

        return target;
    }

    /**
     * Collects the methods declared by the bean's class and its superclasses
     * (excluding {@code Object}), walking from the most derived class upwards.
     * A superclass method whose signature was already collected from a more
     * derived class is skipped; a warning is logged when such a skipped method
     * carries the {@link Spring} annotation.
     *
     * @param bean the object whose class hierarchy is inspected
     * @return the collected methods, ordered by the method comparator
     */
    protected Method[] getAllMethods(Object bean) {
        Set<Method> methods = new TreeSet<Method>(METHOD_COMPARATOR);
        for (Class<?> type = bean.getClass();
             type != null && type != Object.class;
             type = type.getSuperclass()) {
            for (Method method : type.getDeclaredMethods()) {
                // add() returns false when an equal signature from a more
                // derived class is already present, i.e. this one is overridden.
                if (!methods.add(method) && method.isAnnotationPresent(Spring.class)) {
                    log.warn("Found overridden @Spring annotated method: " + method);
                }
            }
        }
        return methods.toArray(new Method[methods.size()]);
    }

    /**
     * Collects the fields declared by the bean's class and its superclasses
     * (excluding {@code Object}), most derived class first.
     *
     * @param bean the object whose class hierarchy is inspected
     * @return all declared fields found in the hierarchy
     */
    protected Field[] getAllFields(Object bean) {
        List<Field> fields = new ArrayList<Field>();
        for (Class<?> type = bean.getClass();
             type != null && type != Object.class;
             type = type.getSuperclass()) {
            fields.addAll(Arrays.asList(type.getDeclaredFields()));
        }
        return fields.toArray(new Field[fields.size()]);
    }

    /** A setter is any single-argument method whose name starts with "set". */
    private boolean isSetterMethod(Method m) {
        return m.getName().startsWith("set") && m.getParameterTypes().length == 1;
    }

    /**
     * Looks up the bean factory named by the annotation and fetches the
     * requested bean from it.
     */
    private Object getObjectFromBeanFactory(Spring spring) throws Exception {
        BeanFactory beanFactory = (BeanFactory) Util.lookup(spring.jndiName(), BeanFactory.class);
        return beanFactory.getBean(spring.bean());
    }

    /** Resolves the bean, verifies its type and passes it to the setter. */
    private void injectToMethod(Object target, Method method, Spring spring) throws Exception {
        Object bean = getObjectFromBeanFactory(spring);
        doAssert(bean, method.getParameterTypes()[0]);
        logInjection(spring, bean, target, method);
        // The setter may be non-public.
        method.setAccessible(true);
        method.invoke(target, bean);
    }

    /** Resolves the bean, verifies its type and assigns it to the field. */
    private void injectToField(Object target, Field field, Spring spring) throws Exception {
        Object bean = getObjectFromBeanFactory(spring);
        doAssert(bean, field.getType());
        logInjection(spring, bean, target, field);
        // The field may be non-public.
        field.setAccessible(true);
        field.set(target, bean);
    }

    /** Asserts that the bean's class is assignable to the expected type. */
    private void doAssert(Object bean, Class<?> expectedBeanClass) {
        Assert.isTrue(expectedBeanClass.isAssignableFrom(bean.getClass()),
                      "Illegal bean class type - " + bean.getClass().getName() +
                      " - cannot be assigned to: " + expectedBeanClass.getName());
    }

    private void logInjection(Spring spring, Object bean, Object target, Member m) {
        log.debug("Injecting bean '" + spring.bean() + "' of class type " +
                  bean.getClass().getName() + " into " + target + " via " + m);
    }

    /**
     * Orders methods by name, then parameter count, then return type name,
     * then parameter type names. Two methods compare as equal (0) when all of
     * these match, which is how an overridden superclass method is detected.
     * <p>
     * Note: modifiers are not inspected, so private or static methods that
     * share a signature across the hierarchy are treated as equal as well.
     */
    private static final class MethodComparator implements Comparator<Method> {

        public int compare(Method m1, Method m2) {
            int byName = m1.getName().compareTo(m2.getName());
            if (byName != 0) {
                return byName;
            }

            Class<?>[] params1 = m1.getParameterTypes();
            Class<?>[] params2 = m2.getParameterTypes();
            if (params1.length != params2.length) {
                // Both lengths are small and non-negative, so this cannot overflow.
                return params1.length - params2.length;
            }

            Class<?> returnType1 = m1.getReturnType();
            Class<?> returnType2 = m2.getReturnType();
            if (!returnType1.equals(returnType2)) {
                return returnType1.getName().compareTo(returnType2.getName());
            }

            for (int i = 0; i < params1.length; i++) {
                if (!params1[i].equals(params2[i])) {
                    return params1[i].getName().compareTo(params2[i].getName());
                }
            }

            return 0; // identical signature: overridden method
        }

    }
}
