/*
  * JBoss, Home of Professional Open Source
  * Copyright 2005, JBoss Inc., and individual contributors as indicated
  * by the @authors tag. See the copyright.txt in the distribution for a
  * full listing of individual contributors.
  *
  * This is free software; you can redistribute it and/or modify it
  * under the terms of the GNU Lesser General Public License as
  * published by the Free Software Foundation; either version 2.1 of
  * the License, or (at your option) any later version.
  *
  * This software is distributed in the hope that it will be useful,
  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  * Lesser General Public License for more details.
  *
  * You should have received a copy of the GNU Lesser General Public
  * License along with this software; if not, write to the Free
  * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
  * 02110-1301 USA, or see the FSF site: http://www.fsf.org.
  */
package org.jboss.ejb3.injection;

import org.hibernate.SessionFactory;
import org.jboss.ejb3.Container;
import org.jboss.ejb3.EJBContainer;
import org.jboss.ejb3.dd.Injectable;
import org.jboss.ejb3.dd.PersistenceUnitRef;
import org.jboss.ejb3.entity.InjectedEntityManagerFactory;
import org.jboss.ejb3.entity.InjectedSessionFactory;
import org.jboss.ejb3.entity.ManagedEntityManagerFactory;
import org.jboss.ejb3.entity.PersistenceUnitDeployment;
import org.jboss.ejb3.interceptor.InterceptorInjector;
import org.jboss.naming.Util;
import org.jboss.logging.Logger;

import javax.naming.NameNotFoundException;
import javax.naming.NamingException;
import javax.persistence.EntityManagerFactory;
import javax.persistence.PersistenceUnit;
import javax.persistence.PersistenceUnits;
import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

/**
 * Processes {@code @PersistenceUnit} / {@code @PersistenceUnits} annotations and
 * {@code <persistence-unit-ref>} XML entries for a bean class or interceptor class.
 * <p>
 * Work is split in two phases:
 * <ul>
 * <li>{@link #loadDependencies} records the XML ref names in the container's
 * {@code getPuEncXmlEntries()} map and registers, with the container's dependency policy,
 * every persistence unit that is referenced;</li>
 * <li>{@link #loadInjectors} binds factories into the container's ENC under {@code env/...}
 * and registers field/method injectors.</li>
 * </ul>
 * An annotation whose ENC name is already present in {@code getPuEncXmlEntries()} is treated
 * as overridden by XML and is not bound again. That map is only populated by
 * {@link #loadDependencies}, so it presumably runs before {@link #loadInjectors} — confirm
 * against callers.
 *
 * @author <a href="mailto:bill@jboss.org">Bill Burke</a>
 * @version $Revision$
 */
public class PersistenceUnitHandler
{
   private static final Logger log = Logger.getLogger(PersistenceUnitHandler.class);

   /** Prefix of every ENC name this handler binds or looks up. */
   private static final String ENC_PREFIX = "env/";

   /**
    * Returns the ENC name for a class-level annotation, which must declare {@code name()}.
    *
    * @throws RuntimeException if {@code name()} is empty
    */
   private static String classLevelEncName(PersistenceUnit ref)
   {
      String name = ref.name();
      if (name == null || name.equals(""))
      {
         throw new RuntimeException("JBoss requires name() for class level @PersistenceUnit");
      }
      return ENC_PREFIX + name;
   }

   /**
    * Returns {@code "env/" + name()} when the annotation declares a name. Otherwise it returns
    * {@code null}, so the caller can derive a default from the annotated member.
    */
   private static String explicitEncName(PersistenceUnit ref)
   {
      String name = ref.name();
      if (name == null || name.equals(""))
      {
         return null;
      }
      return ENC_PREFIX + name;
   }

   /**
    * Fails fast when no factory could be found for an injection target.
    * Otherwise {@code null} would be injected and bound silently.
    */
   private static void requireFactory(Object factory, String unitName, Object target)
   {
      if (factory == null)
      {
         throw new RuntimeException("Unable to find @PersistenceUnit: " + unitName + " for " + target);
      }
   }

   /**
    * Records each XML {@code <persistence-unit-ref>} in the container's XML-entry map, keyed by
    * {@code env/<ref-name>}, and registers a dependency on its persistence unit.
    */
   private static void loadXmlPUDependencies(List<PersistenceUnitRef> refs, EJBContainer container)
   {
      for (PersistenceUnitRef ref : refs)
      {
         container.getPuEncXmlEntries().put(ENC_PREFIX + ref.getRefName(), ref.getUnitName());
         try
         {
            addPUDependency(ref.getUnitName(), container);
         }
         catch (NameNotFoundException e)
         {
            throw new RuntimeException("Illegal <persistence-unit-ref> of " + ref.getRefName() + " :" + e.getMessage(), e);
         }
      }
   }

   /**
    * Binds a factory into the ENC for each XML {@code <persistence-unit-ref>}.
    * <p>
    * When the ref declares an injection target, a JNDI field or method injector for
    * {@code env/<ref-name>} is also registered. The target's type decides whether a
    * {@code SessionFactory} or an {@code EntityManagerFactory} is bound.
    *
    * @throws RuntimeException if the unit cannot be found or the ENC bind fails
    */
   private static void loadPersistenceUnitRefXml(List<PersistenceUnitRef> refs, EJBContainer container, Class<?> clazz, HashMap<AccessibleObject, Injector> injectors)
   {
      for (PersistenceUnitRef ref : refs)
      {
         String encName = ENC_PREFIX + ref.getRefName();
         Class<?> injectionType = null;
         if (ref.getInjectionTarget() != null)
         {
            // todo, get injection target class
            AccessibleObject ao = InjectionUtil.findInjectionTarget(clazz, ref.getInjectionTarget());
            if (ao instanceof Field)
            {
               Field field = (Field) ao;
               injectionType = field.getType();
               injectors.put(field, new JndiFieldInjector(field, encName, container.getEnc()));
            }
            else
            {
               Method method = (Method) ao;
               injectionType = method.getParameterTypes()[0];
               injectors.put(method, new JndiMethodInjector(method, encName, container.getEnc()));
            }
         }

         Object factory;
         try
         {
            factory = getFactory(injectionType, ref.getUnitName(), container);
         }
         catch (NameNotFoundException e)
         {
            throw new RuntimeException("Failed to load <persistence-unit-ref> of unit name: " + ref.getUnitName() + " for EJB " + container.getEjbName(), e);
         }
         if (factory == null)
         {
            throw new RuntimeException("Failed to load <persistence-unit-ref> of unit name: " + ref.getUnitName() + " for EJB " + container.getEjbName());
         }

         try
         {
            Util.bind(container.getEnc(), encName, factory);
         }
         catch (Exception e)
         {
            throw new RuntimeException("Failed to load <persistence-unit-ref> of unit name: " + ref.getUnitName() + " ref-name: " + ref.getRefName() + " for EJB " + container.getEjbName(), e);
         }
      }
   }

   /**
    * Binds an {@code EntityManagerFactory} into the ENC for every class-level
    * {@code @PersistenceUnits} entry and for a class-level {@code @PersistenceUnit}.
    */
   private static void loadPUsAnnotation(EJBContainer container, Class<?> clazz, boolean isContainer) throws Exception
   {
      PersistenceUnits units = (PersistenceUnits) InjectionUtil.getAnnotation(
              PersistenceUnits.class, container, clazz, isContainer);
      if (units != null)
      {
         for (PersistenceUnit ref : units.value())
         {
            bindClassLevelPU(ref, container, clazz);
         }
      }
      PersistenceUnit pu = (PersistenceUnit)
              InjectionUtil.getAnnotation(PersistenceUnit.class, container, clazz, isContainer);
      if (pu != null)
      {
         bindClassLevelPU(pu, container, clazz);
      }
   }

   /**
    * Binds the factory for one class-level annotation, unless XML already defines the same
    * ENC name.
    *
    * @throws RuntimeException if {@code name()} is missing or the unit cannot be found
    */
   private static void bindClassLevelPU(PersistenceUnit ref, EJBContainer container, Class<?> clazz) throws NamingException
   {
      String encName = classLevelEncName(ref);
      // skip if already loaded by XML
      if (container.getPuEncXmlEntries().containsKey(encName))
      {
         return;
      }
      EntityManagerFactory factory = getEntityManagerFactory(ref, container);
      if (factory == null)
      {
         throw new RuntimeException("Unable to find @PersistenceUnit: " + ref.unitName() + " on " + clazz.getName());
      }
      bindEntityManagerFactoryToEnc(container, encName, factory);
   }

   /**
    * Registers dependencies for the class-level {@code @PersistenceUnits} and
    * {@code @PersistenceUnit} annotations of {@code clazz}.
    */
   private static void loadClassPUDependencies(EJBContainer container, Class<?> clazz, boolean isContainer)
   {
      PersistenceUnits units = (PersistenceUnits) InjectionUtil.getAnnotation(
              PersistenceUnits.class, container, clazz, isContainer);
      if (units != null)
      {
         for (PersistenceUnit ref : units.value())
         {
            addClassLevelPUDependency(ref, container, clazz);
         }
      }
      PersistenceUnit pu = (PersistenceUnit)
              InjectionUtil.getAnnotation(PersistenceUnit.class, container, clazz, isContainer);
      if (pu != null)
      {
         addClassLevelPUDependency(pu, container, clazz);
      }
   }

   /**
    * Registers the dependency for one class-level annotation, unless XML already defines the
    * same ENC name.
    */
   private static void addClassLevelPUDependency(PersistenceUnit ref, EJBContainer container, Class<?> clazz)
   {
      String encName = classLevelEncName(ref);
      // skip if already loaded by XML
      if (container.getPuEncXmlEntries().containsKey(encName))
      {
         return;
      }
      try
      {
         addPUDependency(ref.unitName(), container);
      }
      catch (NameNotFoundException e)
      {
         throw new RuntimeException("Invalid " + clazz.getName() + " annotation @PersistenceUnit: " + e.getMessage(), e);
      }
   }

   /**
    * Adds a dependency on the given persistence unit to the container's dependency policy.
    * <p>
    * If the container's deployment already knows the unit, that deployment's kernel name is
    * used. Otherwise the unit's default kernel name is used, since the unit is probably not
    * deployed yet.
    *
    * @param unitName  persistence unit name as written in the annotation or XML
    * @param container container whose dependency policy receives the dependency
    * @throws NameNotFoundException declared for callers; not thrown by the current body
    */
   public static void addPUDependency(String unitName, EJBContainer container) throws NameNotFoundException
   {
      // look in EAR first
      PersistenceUnitDeployment deployment = container.getDeployment().getPersistenceUnitDeployment(unitName);
      if (deployment != null)
      {
         container.getDependencyPolicy().addDependency(deployment.getKernelName());
      }
      else
      {
         // probably not deployed yet.
         container.getDependencyPolicy().addDependency(PersistenceUnitDeployment.getDefaultKernelName(unitName));
      }
   }

   /**
    * Returns the managed factory of the named persistence unit.
    *
    * @throws NameNotFoundException if the container's deployment has no such unit
    */
   public static ManagedEntityManagerFactory getManagedEntityManagerFactory(EJBContainer container, String unitName)
           throws NameNotFoundException
   {
      PersistenceUnitDeployment deployment = container.getDeployment().getPersistenceUnitDeployment(unitName);
      if (deployment == null)
      {
         throw new NameNotFoundException("Unable to find persistence unit: " + unitName + " for EJB container: " + container.getObjectName());
      }
      return deployment.getManagedFactory();
   }

   /** Convenience overload that looks the factory up by {@code ref.unitName()}. */
   private static EntityManagerFactory getEntityManagerFactory(PersistenceUnit ref, EJBContainer container) throws NameNotFoundException
   {
      return getEntityManagerFactory(ref.unitName(), container);
   }

   /**
    * Returns a {@code SessionFactory} wrapper when {@code type} is {@code SessionFactory}.
    * In every other case, including a {@code null} type, it returns an
    * {@code EntityManagerFactory} wrapper. The result is {@code null} if the unit is unknown.
    */
   private static Object getFactory(Class<?> type, String unitName, EJBContainer container) throws NameNotFoundException
   {
      // compared by name, presumably so a SessionFactory loaded by another class loader still matches
      if (type != null && type.getName().equals(SessionFactory.class.getName()))
      {
         return getSessionFactory(unitName, container);
      }
      return getEntityManagerFactory(unitName, container);
   }

   /**
    * Wraps the unit's managed factory in an {@code InjectedEntityManagerFactory}.
    *
    * @return the wrapper, or {@code null} if the container's deployment has no such unit
    */
   private static EntityManagerFactory getEntityManagerFactory(String unitName, EJBContainer container) throws NameNotFoundException
   {
      PersistenceUnitDeployment deployment = container.getDeployment().getPersistenceUnitDeployment(unitName);
      if (deployment == null)
      {
         return null;
      }
      ManagedEntityManagerFactory managedFactory = deployment.getManagedFactory();
      return new InjectedEntityManagerFactory(managedFactory);
   }

   /**
    * Wraps the unit's managed factory in an {@code InjectedSessionFactory}.
    *
    * @return the wrapper, or {@code null} if the container's deployment has no such unit
    */
   private static SessionFactory getSessionFactory(String unitName, EJBContainer container) throws NameNotFoundException
   {
      PersistenceUnitDeployment deployment = container.getDeployment().getPersistenceUnitDeployment(unitName);
      if (deployment == null)
      {
         return null;
      }
      ManagedEntityManagerFactory managedFactory = deployment.getManagedFactory();
      return new InjectedSessionFactory(managedFactory);
   }

   /** Loads persistence-unit injectors for the container's bean class into its ENC injections. */
   public static void loadInjectors(EJBContainer ejb) throws Exception
   {
      Class<?> clazz = ejb.getBeanClass();
      loadInjectors(clazz, ejb.getXml(), ejb, ejb.getEncInjections(), true);
   }

   /** Loads persistence-unit injectors for an interceptor class into the interceptor's ENC injections. */
   public static void loadInjectors(InterceptorInjector injector) throws Exception
   {
      loadInjectors(injector.getClazz(), injector.getXml(), (EJBContainer) injector.getContainer(), injector.getEncInjections(), false);
   }

   /**
    * Processes XML refs first. It then processes class-level annotations, then annotated
    * setter methods, then annotated fields.
    */
   private static void loadInjectors(Class<?> clazz, Injectable xml, EJBContainer container, HashMap<AccessibleObject, Injector> encInjections, boolean isContainer) throws Exception
   {
      if (xml != null)
      {
         loadPersistenceUnitRefXml(xml.getPersistenceUnitRefs(), container, clazz, encInjections);
      }

      HashSet<String> visitedMethods = new HashSet<String>();

      loadPUsAnnotation(container, clazz, isContainer);
      loadMethodInjectors(visitedMethods, container, clazz, encInjections, isContainer);
      loadFieldInjectors(clazz, container, encInjections, isContainer);
   }

   /**
    * Registers dependencies for every persistence unit referenced by {@code xml} and by the
    * annotations on {@code clazz}. Methods and fields are also scanned on its superclasses.
    * Also records the XML ref names in the container's XML-entry map.
    */
   public static void loadDependencies(Injectable xml, EJBContainer container, Class<?> clazz, boolean isContainer)
   {
      if (xml != null)
      {
         loadXmlPUDependencies(xml.getPersistenceUnitRefs(), container);
      }
      loadClassPUDependencies(container, clazz, isContainer);
      HashSet<String> visitedMethods = new HashSet<String>();
      loadMethodDependencies(visitedMethods, container, clazz, isContainer);
      loadFieldDependencies(clazz, container, isContainer);
   }

   /**
    * Binds {@code factory} into the container's ENC under {@code encName}.
    *
    * @throws NamingException wrapping the original failure as its root cause
    */
   private static void bindEntityManagerFactoryToEnc(Container container, String encName, Object factory)
           throws NamingException
   {
      try
      {
         Util.bind(container.getEnc(), encName, factory);
      }
      catch (Exception e)
      {
         String jndiName;
         try
         {
            jndiName = container.getEnc().getNameInNamespace() + "/" + encName;
         }
         catch (Exception ignored)
         {
            // Best effort only: a failure to describe the name must not hide the real bind error.
            jndiName = encName;
         }
         NamingException namingException = new NamingException("Could not bind entity manager factory for EJB container with ejb name " + container.getEjbName() + " into JNDI under jndiName: " + jndiName);
         namingException.setRootCause(e);
         throw namingException;
      }
   }

   /**
    * Registers injectors for {@code @PersistenceUnit}-annotated setter methods of {@code clazz}
    * and its superclasses, and binds the matching factory into the ENC.
    * <p>
    * Only one annotated non-private method per name is processed, and the most-derived class
    * is visited first. If XML already maps the ENC name, the XML unit name is used and nothing
    * is bound again. A {@code EntityManagerFactory} parameter gets an entity manager factory;
    * any other parameter type gets a {@code SessionFactory}.
    *
    * @throws RuntimeException wrapping any failure, naming the offending method
    */
   public static void loadMethodInjectors(HashSet<String> visitedMethods, EJBContainer container, Class<?> clazz, HashMap<AccessibleObject, Injector> injectors, boolean isContainer) throws Exception
   {
      if (clazz == null || clazz.equals(Object.class))
      {
         return;
      }
      for (Method method : clazz.getDeclaredMethods())
      {
         PersistenceUnit ref = (PersistenceUnit) InjectionUtil.getAnnotation(
                 PersistenceUnit.class, container, method, isContainer);
         if (ref == null)
         {
            continue;
         }
         try
         {
            // Set.add returns false when the name was already seen (e.g. an overriding subclass method).
            if (!Modifier.isPrivate(method.getModifiers()) && !visitedMethods.add(method.getName()))
            {
               continue;
            }
            if (!method.getName().startsWith("set"))
               throw new RuntimeException("@PersistenceUnit can only be used with a set method: " + method);
            if (method.getParameterTypes().length != 1)
               throw new RuntimeException("@PersistenceUnit can only be used with a set method of one parameter: " + method);

            String encName = explicitEncName(ref);
            if (encName == null)
            {
               encName = InjectionUtil.getEncName(method);
            }
            String xmlUnitName = container.getPuEncXmlEntries().get(encName);
            boolean bindingExists = xmlUnitName != null;
            String unitName = bindingExists ? xmlUnitName : ref.unitName();

            if (injectors.containsKey(method))
            {
               continue;
            }
            if (method.getParameterTypes()[0].getName().equals(EntityManagerFactory.class.getName()))
            {
               EntityManagerFactory factory = getEntityManagerFactory(unitName, container);
               requireFactory(factory, unitName, method);
               injectors.put(method, new EntityManagerFactoryMethodInjector(method, container, factory));
               if (!bindingExists) bindEntityManagerFactoryToEnc(container, encName, factory);
            }
            else
            {
               SessionFactory factory = getSessionFactory(unitName, container);
               requireFactory(factory, unitName, method);
               injectors.put(method, new EntityManagerFactoryMethodInjector(method, container, factory));
               if (!bindingExists) bindEntityManagerFactoryToEnc(container, encName, factory);
            }
         }
         catch (Exception ex)
         {
            throw new RuntimeException("failed in processing injection annotation on method: " + method.toString(), ex);
         }
      }
      loadMethodInjectors(visitedMethods, container, clazz.getSuperclass(), injectors, isContainer);
   }

   /**
    * Registers dependencies for {@code @PersistenceUnit}-annotated methods of {@code clazz} and
    * its superclasses. Methods whose ENC name is already defined by XML are skipped.
    */
   public static void loadMethodDependencies(HashSet<String> visitedMethods, EJBContainer container, Class<?> clazz, boolean isContainer)
   {
      if (clazz == null || clazz.equals(Object.class))
      {
         return;
      }
      for (Method method : clazz.getDeclaredMethods())
      {
         PersistenceUnit ref = (PersistenceUnit) InjectionUtil.getAnnotation(
                 PersistenceUnit.class, container, method, isContainer);
         if (ref == null)
         {
            continue;
         }
         if (!Modifier.isPrivate(method.getModifiers()) && !visitedMethods.add(method.getName()))
         {
            continue;
         }
         String encName = explicitEncName(ref);
         if (encName == null)
         {
            encName = InjectionUtil.getEncName(method);
         }
         if (container.getPuEncXmlEntries().containsKey(encName))
         {
            continue;
         }
         try
         {
            addPUDependency(ref.unitName(), container);
         }
         catch (NameNotFoundException e)
         {
            throw new RuntimeException("Method " + method.toString() + " @PersistenceUnit in error: " + e.getMessage(), e);
         }
      }
      loadMethodDependencies(visitedMethods, container, clazz.getSuperclass(), isContainer);
   }

   /**
    * Registers dependencies for {@code @PersistenceUnit}-annotated fields. Superclass fields
    * are processed before those of {@code clazz}, and fields whose ENC name is already defined
    * by XML are skipped.
    */
   private static void loadFieldDependencies(Class<?> clazz, EJBContainer container, boolean isContainer)
   {
      if (clazz == null || clazz.equals(Object.class)) return;
      loadFieldDependencies(clazz.getSuperclass(), container, isContainer);
      for (Field field : clazz.getDeclaredFields())
      {
         PersistenceUnit ref = (PersistenceUnit) InjectionUtil.getAnnotation(
                 PersistenceUnit.class, container, field, isContainer);
         if (ref == null)
         {
            continue;
         }
         String encName = explicitEncName(ref);
         if (encName == null)
         {
            encName = InjectionUtil.getEncName(field);
         }
         if (container.getPuEncXmlEntries().containsKey(encName))
         {
            continue;
         }
         try
         {
            addPUDependency(ref.unitName(), container);
         }
         catch (NameNotFoundException e)
         {
            throw new RuntimeException("Field " + field.toString() + " @PersistenceUnit in error: " + e.getMessage(), e);
         }
      }
   }

   /**
    * Registers injectors for {@code @PersistenceUnit}-annotated fields, with superclass fields
    * first, and binds the matching factory into the ENC.
    * <p>
    * If XML already maps the ENC name, the XML unit name is used and nothing is bound again.
    * A field of type {@code EntityManagerFactory} gets an entity manager factory; any other
    * field type gets a {@code SessionFactory}.
    *
    * @throws RuntimeException wrapping any failure, naming the offending field
    */
   private static void loadFieldInjectors(Class<?> clazz, EJBContainer container, HashMap<AccessibleObject, Injector> injectors, boolean isContainer) throws Exception
   {
      if (clazz == null || clazz.equals(Object.class)) return;
      loadFieldInjectors(clazz.getSuperclass(), container, injectors, isContainer);
      for (Field field : clazz.getDeclaredFields())
      {
         try
         {
            PersistenceUnit ref = (PersistenceUnit) InjectionUtil.getAnnotation(
                    PersistenceUnit.class, container, field, isContainer);
            if (ref == null)
            {
               continue;
            }
            field.setAccessible(true);
            String encName = explicitEncName(ref);
            if (encName == null)
            {
               encName = InjectionUtil.getEncName(field);
            }
            String xmlUnitName = container.getPuEncXmlEntries().get(encName);
            boolean bindingExists = xmlUnitName != null;
            String unitName = bindingExists ? xmlUnitName : ref.unitName();

            if (injectors.containsKey(field))
            {
               continue;
            }
            if (field.getType().getName().equals(EntityManagerFactory.class.getName()))
            {
               EntityManagerFactory factory = getEntityManagerFactory(unitName, container);
               requireFactory(factory, unitName, field);
               injectors.put(field, new EntityManagerFactoryFieldInjector(field, container, factory));
               if (!bindingExists) bindEntityManagerFactoryToEnc(container, encName, factory);
            }
            else
            {
               SessionFactory factory = getSessionFactory(unitName, container);
               requireFactory(factory, unitName, field);
               injectors.put(field, new EntityManagerFactoryFieldInjector(field, container, factory));
               if (!bindingExists) bindEntityManagerFactoryToEnc(container, encName, factory);
            }
         }
         catch (Exception ex)
         {
            throw new RuntimeException("failed in processing injection annotation on field: " + field.toString(), ex);
         }
      }
   }
}
