/*
 * JBoss, the OpenSource J2EE webOS
 *
 * Distributable under LGPL license.
 * See terms of license at gnu.org.
 */
package org.jboss.webservice.server;

// $Id:WSDLRequestHandler.java,v 1.0, 2005-06-24 19:19:21Z, Robert Worsnop$

import org.jboss.deployment.DeploymentInfo;
import org.jboss.logging.Logger;
import org.jboss.webservice.metadata.WebserviceDescriptionMetaData;
import org.w3c.dom.Attr;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import javax.wsdl.Definition;
import javax.wsdl.WSDLException;
import javax.wsdl.factory.WSDLFactory;
import javax.wsdl.xml.WSDLWriter;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.StringTokenizer;

/**
 * Handles the delivery of the WSDL and its included artifacts.
 * It rewrites the import and include URLs so that they point back at the request URI.
 *
 * For a discussion of this topic, see
 * http://www.jboss.org/index.html?module=bb&op=viewtopic&p=3871263#3871263
 *
 * @author Thomas.Diesler@jboss.org
 * @since 23-Mar-2005
 */
public class WSDLRequestHandler
{
   // provide logging
   private static final Logger log = Logger.getLogger(WSDLRequestHandler.class);

   /** Supplies the root wsdl definition and the name of the wsdl file inside the deployment. */
   private final WebserviceDescriptionMetaData wsdMetaData;
   /** The deployment whose <code>localCl</code> class loader is used to look up wsdl resources. */
   private final DeploymentInfo di;

   public WSDLRequestHandler(WebserviceDescriptionMetaData wsdMetaData, DeploymentInfo di)
   {
      this.wsdMetaData = wsdMetaData;
      this.di = di;
   }

   /**
    * Get the WSDL resource for a given resource path
    * <p/>
    * Use path value of null to get the root document.
    * <p/>
    * Relative import/include locations in the returned document are rewritten
    * to <code>requestURI + "?wsdl&amp;resource=" + path</code>.
    *
    * @param requestURI   The URI that rewritten import/include locations are based on
    * @param resourcePath The wsdl resource to get, can be null for the top level wsdl
    * @return A wsdl document with its import locations rewritten
    * @throws SecurityException        if the canonical resource is outside WEB-INF/wsdl/ and META-INF/wsdl/
    * @throws IllegalStateException    if the resource cannot be found in the deployment
    * @throws IllegalArgumentException if the resource cannot be parsed
    */
   public Document getDocumentForPath(String requestURI, String resourcePath)
   {
      Document wsdlDoc;

      if (resourcePath == null)
      {
         // Get the root wsdl
         wsdlDoc = getWSDLDocument(wsdMetaData.getWsdlDefinition());
      }
      else
      {
         wsdlDoc = parseResource(resolveResource(resourcePath));
      }

      modifyImportLocations(requestURI, resourcePath, wsdlDoc.getDocumentElement());
      return wsdlDoc;
   }

   /**
    * Resolve a resource path, relative to the directory of the root wsdl file,
    * to a URL inside the deployment.
    */
   private URL resolveResource(String resourcePath)
   {
      String wsdlFile = wsdMetaData.getWsdlFile();

      // A wsdl file name without any '/' has no parent directory; use an empty
      // root instead of failing with a StringIndexOutOfBoundsException.
      int lastSlash = wsdlFile.lastIndexOf('/');
      String rootDir = (lastSlash < 0) ? "" : wsdlFile.substring(0, lastSlash);

      // The check is made on the canonical form, so '..' segments cannot be
      // used to reach resources outside the wsdl directories.
      String resource = canonicalize(rootDir + "/" + resourcePath);
      if (resource.startsWith("WEB-INF/wsdl/") == false && resource.startsWith("META-INF/wsdl/") == false)
         throw new SecurityException("Cannot access a resource below the wsdl root: " + resource);

      // Load the resource from the deployment
      URLClassLoader cl = di.localCl;
      URL resURL = cl.findResource(resource);
      if (resURL == null)
         throw new IllegalStateException("Cannot obtain wsdl resource from: " + resource);

      return resURL;
   }

   /**
    * Parse the given URL into a namespace aware, non-validated DOM document.
    * The stream opened on the URL is always closed.
    */
   private Document parseResource(URL resURL)
   {
      InputStream in = null;
      try
      {
         DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
         factory.setNamespaceAware(true);
         factory.setValidating(false);
         DocumentBuilder builder = factory.newDocumentBuilder();
         in = resURL.openStream();
         return builder.parse(in);
      }
      catch (Exception e)
      {
         // initCause keeps the root cause without needing the (String, Throwable) constructor
         IllegalArgumentException iae = new IllegalArgumentException("Cannot parse wsdl resource: " + resURL);
         iae.initCause(e);
         throw iae;
      }
      finally
      {
         if (in != null)
         {
            try
            {
               in.close();
            }
            catch (IOException ignored)
            {
               // the document has been parsed or the parse failure is already propagating
            }
         }
      }
   }

   /**
    * Get the Document for a given wsdl definition
    *
    * @throws RuntimeException wrapping the WSDLException if the definition cannot be written
    */
   private Document getWSDLDocument(Definition wsdlDefinition)
   {
      try
      {
         WSDLFactory factory = WSDLFactory.newInstance();
         WSDLWriter wsdlWriter = factory.newWSDLWriter();
         return wsdlWriter.getDocument(wsdlDefinition);
      }
      catch (WSDLException e)
      {
         throw new RuntimeException(e);
      }
   }

   /**
    * Modify the location of wsdl and schema imports
    * <p/>
    * Walks the element tree. For every <code>import</code> or <code>include</code>
    * element, the <code>schemaLocation</code> attribute (or, if absent, the
    * <code>location</code> attribute) is rewritten when it is neither an
    * http(s) URL nor already prefixed with the request URI.
    *
    * @param requestURI   The URI that the new locations are based on
    * @param resourcePath The path of the document being processed, null for the root wsdl
    * @param element      The element whose children are inspected
    */
   private void modifyImportLocations(String requestURI, String resourcePath, Element element)
   {
      // map wsdl definition imports
      NodeList nlist = element.getChildNodes();
      for (int i = 0; i < nlist.getLength(); i++)
      {
         Node childNode = nlist.item(i);
         if (childNode.getNodeType() == Node.ELEMENT_NODE)
         {
            Element childElement = (Element)childNode;
            String nodeName = childElement.getLocalName();
            if ("import".equals(nodeName) || "include".equals(nodeName))
            {
               Attr locationAttr = childElement.getAttributeNode("schemaLocation");
               if (locationAttr == null)
                  locationAttr = childElement.getAttributeNode("location");

               if (locationAttr != null)
               {
                  String orgLocation = locationAttr.getNodeValue();
                  boolean isAbsolute = orgLocation.startsWith("http://") || orgLocation.startsWith("https://");
                  if (isAbsolute == false && orgLocation.startsWith(requestURI) == false)
                  {
                     String resource = orgLocation;

                     // This covers an include from within another include
                     // http://jira.jboss.com/jira/browse/JBWS-153
                     if (resourcePath != null && resourcePath.indexOf("/") > 0)
                     {
                        // prefix with the directory part of the including resource
                        resource = resourcePath.substring(0, resourcePath.lastIndexOf("/") + 1);
                        resource = resource + orgLocation;
                     }

                     String newLocation = requestURI + "?wsdl&resource=" + resource;
                     locationAttr.setNodeValue(newLocation);

                     log.debug("Mapping import from '" + orgLocation + "' to '" + newLocation + "'");
                  }
               }
            }
            else
            {
               // import/include elements are not descended into; everything else is
               modifyImportLocations(requestURI, resourcePath, childElement);
            }
         }
      }
   }

   /**
    * Canonicalizes a path, removing .. and . references.
    * <p/>
    * Empty segments (leading, trailing or doubled '/') are dropped. A '..' that
    * would climb above the start of the path is kept in the result, and later
    * '..' segments do not cancel it, so a path that escapes its root never
    * canonicalizes to a path inside it.
    */
   private String canonicalize(String path)
   {
      StringTokenizer tok = new StringTokenizer(path, "/");
      List parts = new ArrayList();
      while (tok.hasMoreTokens())
      {
         String t = tok.nextToken();
         if (".".equals(t))
         {
            // do nothing
         }
         else if ("..".equals(t) && parts.isEmpty() == false && "..".equals(parts.get(parts.size() - 1)) == false)
         {
            // pop off the last real segment
            parts.remove(parts.size() - 1);
         }
         else
         {
            // regular segment, or a '..' that cannot be resolved against a real segment
            parts.add(t);
         }
      }

      StringBuffer ret = new StringBuffer();
      for (Iterator iter = parts.iterator(); iter.hasNext();)
      {
         ret.append((String)iter.next());
         if (iter.hasNext())
            ret.append('/');
      }
      return ret.toString();
   }

}
