/*
 * JBoss, Home of Professional Open Source
 * Copyright 2005, JBoss Inc., and individual contributors as indicated
 * by the @authors tag. See the copyright.txt in the distribution for a
 * full listing of individual contributors.
 *
 * This is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of
 * the License, or (at your option) any later version.
 *
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this software; if not, write to the Free
 * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301 USA, or see the FSF site: http://www.fsf.org.
 */
package org.jboss.ejb3.injection;

import org.jboss.ejb3.EJBContainer;
import org.jboss.ejb3.dd.Injectable;
import org.jboss.ejb3.dd.PersistenceContextRef;
import org.jboss.ejb3.entity.*;
import org.jboss.ejb3.interceptor.InterceptorInjector;
import org.jboss.logging.Logger;
import org.jboss.naming.Util;

import javax.naming.NameNotFoundException;
import javax.naming.NamingException;
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.PersistenceContextType;
import javax.persistence.PersistenceContexts;
import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

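/**
 * Searches a bean class (and its deployment descriptor) for all
 * {@code @PersistenceContext} / {@code @PersistenceContexts} annotations and
 * {@code <persistence-context-ref>} entries, registers the persistence-unit
 * dependencies they imply, binds the persistence contexts into the ENC and
 * creates the corresponding Injectors.
 *
 * @author <a href="mailto:bill@jboss.org">Bill Burke</a>
 * @version $Revision$
 */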
public class PersistenceContextHandler
{
   private static final Logger log = Logger
           .getLogger(PersistenceContextHandler.class);

   /** Prefix prepended to every explicitly named persistence-context ENC entry. */
   private static final String ENC_PREFIX = "env/";

   /**
    * Records each XML-declared persistence context in the container's PU ENC
    * table and registers a dependency on the persistence unit it names.
    *
    * @param refs      the {@code <persistence-context-ref>} entries to process
    * @param container the container whose ENC table and dependencies are updated
    * @throws RuntimeException if PersistenceUnitHandler rejects a unit name
    *                          (cause is the original NameNotFoundException)
    */
   private static void loadXmlDependencies(List<PersistenceContextRef> refs,
                                           EJBContainer container)
   {
      for (PersistenceContextRef ref : refs)
      {
         container.getPuEncXmlEntries().put(ENC_PREFIX + ref.getRefName(),
                 ref.getUnitName());
         registerPUDependency(ref.getUnitName(), container,
                 "Illegal <persistence-context-ref> of " + ref.getRefName() + " :");
      }
   }

   /**
    * Binds every XML-declared persistence context into the ENC and, when the
    * reference names an injection target, creates a JNDI injector for it.
    *
    * @param refs      the {@code <persistence-context-ref>} entries to process
    * @param container the container owning the ENC
    * @param clazz     the class in which injection targets are looked up
    * @param injectors map receiving one injector per resolved injection target
    */
   private static void loadXml(List<PersistenceContextRef> refs,
                               EJBContainer container, Class<?> clazz,
                               HashMap<AccessibleObject, Injector> injectors)
   {
      for (PersistenceContextRef ref : refs)
      {
         String encName = ENC_PREFIX + ref.getRefName();
         String error = "unable to load <persistence-context-ref> for unitName: "
                 + ref.getUnitName() + " <ref-name>: " + ref.getRefName();
         // Stays null when there is no injection target; bindPersistenceContext
         // then defaults to an EntityManager-style binding.
         Class<?> injectionType = null;
         if (ref.getInjectionTarget() != null)
         {
            AccessibleObject target = InjectionUtil.findInjectionTarget(clazz,
                    ref.getInjectionTarget());
            if (target instanceof Field)
            {
               Field field = (Field) target;
               injectionType = field.getType();
               injectors.put(field, new JndiFieldInjector(field, encName,
                       container.getEnc()));
            }
            else
            {
               // NOTE(review): assumes any non-Field target is a one-argument
               // setter Method; confirm findInjectionTarget guarantees this.
               Method method = (Method) target;
               injectionType = method.getParameterTypes()[0];
               injectors.put(method, new JndiMethodInjector(method, encName,
                       container.getEnc()));
            }
         }
         bindPersistenceContext(container, ref.getUnitName(), error,
                 ref.getPersistenceContextType(), encName, injectionType);
      }
   }

   /**
    * Creates the persistence-context proxy for {@code unitName} and binds it into
    * the container's ENC under {@code encName}.
    * <p>
    * For {@link PersistenceContextType#EXTENDED} an
    * {@code ExtendedPersistenceContextInjector} is also registered in the
    * container's extended-PC map, keyed by the factory's kernel name.
    *
    * @param container     the container owning the ENC
    * @param unitName      the persistence unit to resolve
    * @param error         message used for any RuntimeException thrown here
    * @param type          transaction-scoped or extended
    * @param encName       ENC name to bind under
    * @param injectionType declared type of the injection target, or {@code null};
    *                      anything other than EntityManager (or null) gets a
    *                      Hibernate-session flavoured proxy
    * @throws RuntimeException if the unit cannot be resolved or the bind fails
    */
   private static void bindPersistenceContext(EJBContainer container,
                                              String unitName, String error, PersistenceContextType type,
                                              String encName, Class<?> injectionType)
   {
      ManagedEntityManagerFactory factory;
      try
      {
         factory = PersistenceUnitHandler.getManagedEntityManagerFactory(
                 container, unitName);
      }
      catch (NameNotFoundException e)
      {
         throw new RuntimeException(error + " " + e.getMessage(), e);
      }
      if (factory == null)
      {
         throw new RuntimeException(error);
      }

      boolean entityManagerWanted = isEntityManagerType(injectionType);
      Object persistenceContext;
      if (type == PersistenceContextType.EXTENDED)
      {
         container.getExtendedPCs().put(factory.getKernelName(),
                 new ExtendedPersistenceContextInjector(container, factory));
         if (entityManagerWanted)
         {
            persistenceContext = new ExtendedEntityManager(factory.getKernelName());
         }
         else
         {
            persistenceContext = new ExtendedHibernateSession(factory.getKernelName());
         }
      }
      else
      {
         if (entityManagerWanted)
         {
            persistenceContext = new TransactionScopedEntityManager(factory);
         }
         else
         {
            persistenceContext = new TransactionScopedHibernateSession(factory);
         }
      }

      try
      {
         Util.bind(container.getEnc(), encName, persistenceContext);
      }
      catch (NamingException e)
      {
         throw new RuntimeException(error, e);
      }
   }

   /**
    * Returns true when {@code injectionType} is absent or is EntityManager.
    * The comparison is by class name; presumably this tolerates an EntityManager
    * class from a different class loader (TODO confirm).
    */
   private static boolean isEntityManagerType(Class<?> injectionType)
   {
      return injectionType == null
              || injectionType.getName().equals(EntityManager.class.getName());
   }

   /**
    * Binds every class-level {@code @PersistenceContexts} entry and the
    * class-level {@code @PersistenceContext} (if any) into the ENC.
    */
   private static void loadPersistenceContextClassAnnotations(
           EJBContainer container, Class<?> clazz, boolean isContainer)
   {
      PersistenceContexts resources = (PersistenceContexts) InjectionUtil
              .getAnnotation(PersistenceContexts.class, container, clazz,
                      isContainer);
      if (resources != null)
      {
         for (PersistenceContext ref : resources.value())
         {
            loadPersistenceContextClassAnnotation(ref, container);
         }
      }
      PersistenceContext pc = (PersistenceContext) InjectionUtil.getAnnotation(
              PersistenceContext.class, container, clazz, isContainer);
      if (pc != null)
      {
         loadPersistenceContextClassAnnotation(pc, container);
      }
   }

   /**
    * Binds one class-level {@code @PersistenceContext} into the ENC unless the
    * deployment descriptor already declared the same ENC name (XML wins).
    *
    * @throws RuntimeException if the annotation has no name() or binding fails
    */
   private static void loadPersistenceContextClassAnnotation(
           PersistenceContext ref, EJBContainer container)
   {
      String encName = classLevelEncName(ref);
      if (container.getPuEncXmlEntries().containsKey(encName))
      {
         return;
      }

      String error = "Unable to load class-level @PersistenceContext("
              + ref.unitName() + ") on " + container.getObjectName();
      bindPersistenceContext(container, ref.unitName(), error, ref.type(),
              encName, null);
   }

   /**
    * Registers the persistence-unit dependencies implied by the class-level
    * {@code @PersistenceContexts} / {@code @PersistenceContext} annotations.
    */
   private static void loadClassPUDependencies(EJBContainer container,
                                               Class<?> clazz, boolean isContainer)
   {
      PersistenceContexts resources = (PersistenceContexts) InjectionUtil
              .getAnnotation(PersistenceContexts.class, container, clazz,
                      isContainer);
      if (resources != null)
      {
         for (PersistenceContext ref : resources.value())
         {
            addClassPUDependency(ref, container, clazz);
         }
      }
      PersistenceContext pc = (PersistenceContext) InjectionUtil.getAnnotation(
              PersistenceContext.class, container, clazz, isContainer);
      if (pc != null)
      {
         addClassPUDependency(pc, container, clazz);
      }
   }

   /**
    * Registers the dependency for one class-level annotation, skipping it when
    * the deployment descriptor already declared the same ENC name.
    */
   private static void addClassPUDependency(PersistenceContext ref,
                                            EJBContainer container, Class<?> clazz)
   {
      String encName = classLevelEncName(ref);
      if (container.getPuEncXmlEntries().containsKey(encName))
      {
         return;
      }
      registerPUDependency(ref.unitName(), container, "Invalid "
              + clazz.getName() + " annotation @PersistenceContext: ");
   }

   /**
    * Returns the ENC name for a class-level annotation; class-level
    * annotations have no member to derive a default name from.
    *
    * @throws RuntimeException if name() is empty
    */
   private static String classLevelEncName(PersistenceContext ref)
   {
      String name = ref.name();
      if (name == null || name.equals(""))
      {
         throw new RuntimeException(
                 "JBoss requires name() for class level @PersistenceContext");
      }
      return ENC_PREFIX + name;
   }

   /** ENC name for a method-level annotation: explicit name() or a derived default. */
   private static String encNameFor(PersistenceContext ref, Method method)
   {
      String name = ref.name();
      if (name == null || name.equals(""))
      {
         return InjectionUtil.getEncName(method);
      }
      return ENC_PREFIX + name;
   }

   /** ENC name for a field-level annotation: explicit name() or a derived default. */
   private static String encNameFor(PersistenceContext ref, Field field)
   {
      String name = ref.name();
      if (name == null || name.equals(""))
      {
         return InjectionUtil.getEncName(field);
      }
      return ENC_PREFIX + name;
   }

   /**
    * Registers a dependency of {@code container} on persistence unit
    * {@code unitName}. A NameNotFoundException is rethrown as a RuntimeException
    * whose message is {@code errorPrefix} followed by the lookup message. The
    * original exception is kept as the cause.
    */
   private static void registerPUDependency(String unitName,
                                            EJBContainer container, String errorPrefix)
   {
      try
      {
         PersistenceUnitHandler.addPUDependency(unitName, container);
      }
      catch (NameNotFoundException e)
      {
         throw new RuntimeException(errorPrefix + e.getMessage(), e);
      }
   }

   /**
    * Registers persistence-unit dependencies for {@code @PersistenceContext}
    * methods of {@code clazz} and, recursively, its superclasses.
    * <p>
    * For non-private methods only the first (most-derived) declaration with a
    * given name is processed. Matching is by name only, so same-named overloads
    * are skipped too. Methods whose ENC name the deployment descriptor already
    * declared are skipped.
    *
    * @param visitedMethods names of non-private methods already processed; updated
    * @param container      the container receiving the dependencies
    * @param clazz          class to scan; {@code null} or Object ends the recursion
    * @param isContainer    forwarded to InjectionUtil.getAnnotation
    */
   public static void loadMethodDependencies(HashSet<String> visitedMethods,
                                             EJBContainer container, Class clazz, boolean isContainer)
   {
      if (clazz == null || clazz.equals(Object.class))
      {
         return;
      }
      for (Method method : clazz.getDeclaredMethods())
      {
         PersistenceContext ref = (PersistenceContext) InjectionUtil
                 .getAnnotation(PersistenceContext.class, container, method,
                         isContainer);
         if (ref == null)
         {
            continue;
         }
         // Set.add returns false when the name was already visited.
         if (!Modifier.isPrivate(method.getModifiers())
                 && !visitedMethods.add(method.getName()))
         {
            continue;
         }
         String encName = encNameFor(ref, method);
         if (container.getPuEncXmlEntries().containsKey(encName))
         {
            continue;
         }
         registerPUDependency(ref.unitName(), container, "Method "
                 + method.toString() + " @PersistenceContext in error: ");
      }
      loadMethodDependencies(visitedMethods, container, clazz.getSuperclass(),
              isContainer);
   }

   /**
    * Registers persistence-unit dependencies for {@code @PersistenceContext}
    * fields. Superclasses are processed first. Fields whose ENC name the
    * deployment descriptor already declared are skipped.
    */
   private static void loadFieldDependencies(Class<?> clazz,
                                             EJBContainer container, boolean isContainer)
   {
      if (clazz == null || clazz.equals(Object.class))
      {
         return;
      }
      loadFieldDependencies(clazz.getSuperclass(), container, isContainer);
      for (Field field : clazz.getDeclaredFields())
      {
         PersistenceContext ref = (PersistenceContext) InjectionUtil
                 .getAnnotation(PersistenceContext.class, container, field,
                         isContainer);
         if (ref == null)
         {
            continue;
         }
         String encName = encNameFor(ref, field);
         if (container.getPuEncXmlEntries().containsKey(encName))
         {
            continue;
         }
         // Message previously said "@PersistenceUnit", but this handler only
         // processes @PersistenceContext.
         registerPUDependency(ref.unitName(), container, "Field "
                 + field.toString() + " @PersistenceContext in error: ");
      }
   }

   /**
    * Binds all persistence contexts declared for the EJB's bean class and
    * populates the container's ENC injector map.
    *
    * @param ejb the container whose bean class is processed
    * @throws RuntimeException if the bean class is {@code null}
    */
   public static void loadInjectors(EJBContainer ejb) throws Exception
   {
      Class<?> clazz = ejb.getBeanClass();
      if (clazz == null)
      {
         throw new RuntimeException("CLAZZ CANNOT BE NULL");
      }
      loadInjectors(clazz, ejb.getXml(), ejb, ejb.getEncInjections(), true);
   }

   /**
    * Binds all persistence contexts declared for an interceptor class and
    * populates the interceptor's ENC injector map.
    *
    * @param injector the interceptor injector; its container is assumed to be
    *                 an EJBContainer (cast unchecked, as before)
    */
   public static void loadInjectors(InterceptorInjector injector)
           throws Exception
   {
      loadInjectors(injector.getClazz(), injector.getXml(),
              (EJBContainer) injector.getContainer(),
              injector.getEncInjections(), false);
   }

   /**
    * Processes persistence contexts in this order: XML references, class-level
    * annotations, method annotations, then field annotations.
    */
   private static void loadInjectors(Class<?> clazz, Injectable xml,
                                     EJBContainer container,
                                     HashMap<AccessibleObject, Injector> encInjections, boolean isContainer)
           throws Exception
   {
      if (xml != null)
      {
         loadXml(xml.getPersistenceContextRefs(), container, clazz,
                 encInjections);
      }
      loadPersistenceContextClassAnnotations(container, clazz, isContainer);
      loadMethodInjectors(new HashSet<String>(), container, clazz,
              encInjections, isContainer);
      loadFieldInjectors(clazz, container, encInjections, isContainer);
   }

   /**
    * Registers every persistence-unit dependency implied by the deployment
    * descriptor and by class-, method- and field-level annotations on
    * {@code clazz} and its superclasses.
    *
    * @param xml         descriptor data, or {@code null} if none
    * @param container   the container receiving the dependencies
    * @param clazz       the class to scan
    * @param isContainer forwarded to InjectionUtil.getAnnotation
    */
   public static void loadDependencies(Injectable xml, EJBContainer container,
                                       Class clazz, boolean isContainer)
   {
      if (xml != null)
      {
         loadXmlDependencies(xml.getPersistenceContextRefs(), container);
      }
      loadClassPUDependencies(container, clazz, isContainer);
      loadMethodDependencies(new HashSet<String>(), container, clazz, isContainer);
      loadFieldDependencies(clazz, container, isContainer);
   }

   /**
    * Creates JNDI method injectors for {@code @PersistenceContext} setters of
    * {@code clazz} and, recursively, its superclasses. The ENC binding is made
    * only when the deployment descriptor did not already declare the name.
    * Methods already present in {@code injectors} are left untouched.
    *
    * @throws RuntimeException if an annotated method is not a one-argument
    *                          {@code set*} method, or binding fails
    */
   private static void loadMethodInjectors(HashSet<String> visitedMethods,
                                           EJBContainer container, Class<?> clazz,
                                           HashMap<AccessibleObject, Injector> injectors, boolean isContainer)
           throws Exception
   {
      if (clazz == null || clazz.equals(Object.class))
      {
         return;
      }
      for (Method method : clazz.getDeclaredMethods())
      {
         PersistenceContext ref = (PersistenceContext) InjectionUtil
                 .getAnnotation(PersistenceContext.class, container, method,
                         isContainer);
         if (ref == null)
         {
            continue;
         }
         // Only the most-derived non-private method per name is processed.
         if (!Modifier.isPrivate(method.getModifiers())
                 && !visitedMethods.add(method.getName()))
         {
            continue;
         }

         if (!method.getName().startsWith("set"))
         {
            throw new RuntimeException(
                    "@PersistenceContext can only be used with a set method: "
                            + method);
         }
         Class<?>[] parameterTypes = method.getParameterTypes();
         if (parameterTypes.length != 1)
         {
            throw new RuntimeException(
                    "@PersistenceContext can only be used with a set method of one parameter: "
                            + method);
         }
         String encName = encNameFor(ref, method);
         String error = "@PersistenceContext(name='" + encName
                 + "',unitName='" + ref.unitName() + "') on EJB: "
                 + container.getObjectName() + " failed to inject on method "
                 + method.toString();
         boolean bindingExists = container.getPuEncXmlEntries().containsKey(
                 encName);
         if (!injectors.containsKey(method))
         {
            if (!bindingExists)
            {
               bindPersistenceContext(container, ref.unitName(), error,
                       ref.type(), encName, parameterTypes[0]);
            }
            injectors.put(method, new JndiMethodInjector(method, encName,
                    container.getEnc()));
         }
      }
      loadMethodInjectors(visitedMethods, container, clazz.getSuperclass(),
              injectors, isContainer);
   }

   /**
    * Creates JNDI field injectors for {@code @PersistenceContext} fields.
    * Superclasses are processed first. Annotated fields are made accessible.
    * The ENC binding is made only when the deployment descriptor did not
    * already declare the name. Fields already in {@code injectors} are left
    * untouched.
    */
   private static void loadFieldInjectors(Class<?> clazz, EJBContainer container,
                                          HashMap<AccessibleObject, Injector> injectors, boolean isContainer)
           throws Exception
   {
      if (clazz == null || clazz.equals(Object.class))
      {
         return;
      }
      loadFieldInjectors(clazz.getSuperclass(), container, injectors,
              isContainer);
      for (Field field : clazz.getDeclaredFields())
      {
         PersistenceContext ref = (PersistenceContext) InjectionUtil
                 .getAnnotation(PersistenceContext.class, container, field,
                         isContainer);
         if (ref == null)
         {
            continue;
         }
         field.setAccessible(true);
         String encName = encNameFor(ref, field);
         String error = "@PersistenceContext(name='" + encName
                 + "',unitName='" + ref.unitName() + "') on EJB: "
                 + container.getObjectName() + " failed to inject on field "
                 + field.toString();
         boolean bindingExists = container.getPuEncXmlEntries().containsKey(
                 encName);
         if (!injectors.containsKey(field))
         {
            if (!bindingExists)
            {
               bindPersistenceContext(container, ref.unitName(), error,
                       ref.type(), encName, field.getType());
            }
            injectors.put(field, new JndiFieldInjector(field, encName,
                    container.getEnc()));
         }
      }
   }
}
