//
// Copyright (C) 2004 Gehriger Engineering.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//
//------------------------------------------------------------------------------
#include "msiextract.h"
#include <exception>
#include "resource.h"
#include "getopt.h"

//------------------------------------------------------------------------------
// Main program entry
//
// Parses the command line, opens the MSI package read-only and extracts either
// every stored file (no -x options) or only the files named with -x.
//
// Returns 0 on success, -1 on a command-line or runtime error (errors are
// reported on stderr). Note that parseCommandLine() itself may call exit().
//------------------------------------------------------------------------------
int __cdecl _tmain(int argc, _TCHAR* argv[], _TCHAR* envp[])
{
    try
    {
        // parse command line
        tstring msiFile;
        FileList fileList;
        if (!parseCommandLine(argc, argv, fileList, msiFile))
            return -1;

        // open msi db
        SmrtMsiHandle db;
        OK(MsiOpenDatabase(msiFile.c_str(), MSIDBOPEN_READONLY, &db));

        // get list of stored files (File key -> Media.Cabinet value)
        StoredFiles storedFiles;
        getStoredFiles(db, storedFiles);

        // extract each file
        if (fileList.empty())
        {
            // extract everything
            for (StoredFiles::const_iterator it = storedFiles.begin(); it != storedFiles.end(); ++it)
            {
                extractStoredFile(db, it->first.c_str(), it->second.c_str());
            }
        }
        else
        {
            // extract only the files requested with -x; unknown names are reported and skipped
            for (FileList::const_iterator it = fileList.begin(); it != fileList.end(); ++it)
            {
                StoredFiles::const_iterator itCab = storedFiles.find(*it);
                if (itCab == storedFiles.end())
                {
                    tcerr << _T("File '") << *it << _T("' not found!") << std::endl;
                    continue;
                }

                extractStoredFile(db, itCab->first.c_str(), itCab->second.c_str());
            }
        }
    }
    catch (const _com_error& e)
    {
        // E_FAIL carries no useful system message, so report it generically
        if (e.Error() != E_FAIL)
        {
            tcerr << e.ErrorMessage() << std::endl;
        }
        else
        {
            tcerr << _T("Unknown error!") << std::endl;
        }

        return -1;
    }
    catch (const std::exception& e)
    {
        // e.g. std::bad_alloc while buffering a large cabinet stream; without this
        // handler the exception would escape main and terminate the process
        tcerr << _T("Error: ") << e.what() << std::endl;
        return -1;
    }

    return 0;
}

//------------------------------------------------------------------------------
// Build the map of stored files.
//
// Each File table key is paired with the Media.Cabinet value of the media row
// with the smallest LastSequence that is >= the file's Sequence number. If no
// media row covers the sequence, the file is paired with an empty string.
//
// Failures reported by the MSI API are routed through OK().
//------------------------------------------------------------------------------
void getStoredFiles(MSIHANDLE db, StoredFiles& storedFiles)
{
    // (LastSequence, Cabinet) pairs taken from the Media table
    typedef std::list<std::pair<int, tstring> > MediaRanges;

    SmrtMsiHandle record;
    SmrtMsiHandle view;
    UINT status;

    // Pass 1: collect every media row and sort ascending by LastSequence
    MediaRanges media;
    OK(MsiDatabaseOpenView(db, _T("SELECT `LastSequence`, `Cabinet` FROM `Media`"), &view));
    OK(MsiViewExecute(view, NULL));

    for (status = MsiViewFetch(view, &record);
         status != ERROR_NO_MORE_ITEMS;
         status = MsiViewFetch(view, &record))
    {
        OK(status);
        const int lastSequence = MsiRecordGetInteger(record, 1);
        const tstring cabinetName = recordGetString(record, 2);
        media.push_back(std::make_pair(lastSequence, cabinetName));
    }
    media.sort();

    // Pass 2: assign each file to the first media range that reaches its sequence
    OK(MsiDatabaseOpenView(db, _T("SELECT `File`, `Sequence` FROM `File` ORDER BY `Sequence`"), &view));
    OK(MsiViewExecute(view, NULL));

    for (status = MsiViewFetch(view, &record);
         status != ERROR_NO_MORE_ITEMS;
         status = MsiViewFetch(view, &record))
    {
        OK(status);
        const tstring fileKey = recordGetString(record, 1);
        const int fileSequence = MsiRecordGetInteger(record, 2);

        // skip ranges that end before this file's sequence number
        MediaRanges::const_iterator owner = media.begin();
        while (owner != media.end() && owner->first < fileSequence)
            ++owner;

        storedFiles.insert(std::make_pair(fileKey, owner != media.end() ? owner->second : tstring()));
    }
}

//------------------------------------------------------------------------------
// extract stored file
//------------------------------------------------------------------------------
void extractStoredFile(MSIHANDLE db, const TCHAR* file, const TCHAR* cab)
{
    // if internal cab, extract to temporary file
    if (cab[0] != _T('#'))
    {
        tcerr << _T("Error: extraction from external cabinet files not (yet) supported: ") << cab << std::endl;
        return;
    }

    tstring query = _T("SELECT `Data` FROM `_Streams` WHERE `Name`='") + tstring(cab+1) + _T("'");
    SmrtMsiHandle view;
    OK(MsiDatabaseOpenView(db, query.c_str(), &view));
    OK(MsiViewExecute(view, NULL));

    // fetch the stream
    SmrtMsiHandle record;
    if (MsiViewFetch(view, &record) == ERROR_NO_MORE_ITEMS)
    {
        tcerr << _T("Warning: missing embedded stream ") << cab << std::endl;
        return;
    }

    // read binary data
    DWORD cbSize = MsiRecordDataSize(record, 1);
    std::vector<char> bufBinary(cbSize+1);
    OK(MsiRecordReadStream(record, 1, &bufBinary[0], &cbSize));

    // create temporary file name
    _TCHAR tmpDir[MAX_PATH];
    _TCHAR tmpFile[MAX_PATH];
    GetTempPath(MAX_PATH, tmpDir);
    GetTempFileName(tmpDir, _T("cab"), 0, tmpFile);

    // get current directory
    _TCHAR currentDir[MAX_PATH+1];
    GetCurrentDirectory(MAX_PATH, currentDir);

    // create file
    SmrtFileHandle hFile(CreateFile(tmpFile, GENERIC_WRITE, 0, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL));
    if (hFile.isNull()) 
        _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));

    // write file
    DWORD dwWritten = 0;
    WriteFile(hFile, (LPCVOID)&bufBinary[0], cbSize, &dwWritten, NULL);
    CloseHandle(hFile.release());
    
    // extract file from cab
    CabExtract cabex(tmpFile);
    if (!cabex.extractTo(currentDir, extractCallback, (void*)file))
        _com_issue_error(HRESULT_FROM_WIN32(GetLastError()));

    // delete temporary cab
    DeleteFile(tmpFile);
}

//------------------------------------------------------------------------------
// Cab extraction callback
//
// pv carries the requested file name (a const TCHAR* passed by
// extractStoredFile). When extracted is false the return value decides whether
// the cabinet entry is extracted: only the entry whose name matches
// (case-insensitively) is taken. When extracted is true the callback always
// returns true — presumably "continue processing"; confirm against CabExtract.
// size is not used.
//------------------------------------------------------------------------------
bool __stdcall extractCallback(void* pv, bool extracted, LPCTSTR entry, size_t size)
{
    if (extracted)
        return true;

    const TCHAR* const wanted = static_cast<const TCHAR*>(pv);
    return _tcsicmp(wanted, entry) == 0;
}

//------------------------------------------------------------------------------
// Print short usage message to stderr.
//
// Lists both options accepted by parseCommandLine (-l and -x); the startup
// banner points users at '-l', so it must appear here as well.
//------------------------------------------------------------------------------
void printUsage()
{
    tcerr << _T("\nUsage: ") << std::endl;
    tcerr << _T("msiextract [-l] [-x FILENAME] file") << std::endl;
    tcerr << _T(" -l --license                  show license and exit") << std::endl;
    tcerr << _T(" -x --extract                  extract only specified file") << std::endl;
    tcerr << _T("                               (repeat option to extract multiple files)") << std::endl;
    tcerr << std::endl;
}

//------------------------------------------------------------------------------
// Parse command line
//------------------------------------------------------------------------------
bool parseCommandLine(int argc, _TCHAR* argv[], FileList& fileList, tstring& msiFile)
{
    tcerr << _T("msiextract 1.0");
#ifdef _UNICODE
    tcerr << _T(" (Unicode)");
#endif
    tcerr << _T(", Copyright (C) 2004 Gehriger Engineering.") << std::endl << std::endl;
    tcerr << _T("Msiextract comes with ABSOLUTELY NO WARRANTY.") << std::endl;
    tcerr << _T("This is free software, and you are welcome to redistribute it") << std::endl;
    tcerr << _T("under certain conditions; use the '-l' option for details.") << std::endl << std::endl;

    // short option string (option letters followed by a colon ':' require an argument)
    static const _TCHAR optstring[] = _T("lx:");

    // mapping of long to short arguments
    static const Option longopts[] = 
    {
        // general options
        { _T("license"),            no_argument,        NULL,   _T('l') },
        { _T("extract"),            required_argument,  NULL,   _T('x') },
        { NULL,                     0,                  NULL,   0       }
    };

    int c;
    int longIdx = 0;
    while ((c = getopt_long(argc, argv, optstring, longopts, &longIdx)) != -1) 
    {
        if (optarg != NULL && (*optarg == _T('-') || *optarg == _T('/')) ||
            optind == argc) 
        {
            optarg = NULL;
            --optind;
        }

        switch (c) 
        {
        case _T('x'):  // extract specified file
            if (optarg)
            {
                fileList.push_back(optarg);
            }
            break;

        case _T('l'):  // show license and exit
            std::cerr << loadTextResource(IDR_GPL);
            exit(0);

        case _T('?'):  // invalid argument
            printUsage();
            exit(2);
        }
    }

    if (optind != argc - 1) 
    {
        printUsage();
        exit(2);
    }

    if (argv[optind]) 
    {
        msiFile = argv[optind];
    }

    return !msiFile.empty();
}

//------------------------------------------------------------------------------
// Load text resource
//
// Looks up a resource of custom type "TEXT" in the executable module and
// returns a pointer to its raw bytes. Per the Win32 LockResource documentation
// the pointer refers to the module's resource data and needs no unlock/free.
//
// Parameters:
//
//  resourceId             - integer resource ID
//
// Returns:
//
//  pointer to the resource text, or NULL if the resource cannot be found,
//  loaded or locked.
//  NOTE(review): the bytes are only NUL-terminated if the resource file itself
//  ends with a NUL — confirm before treating the result as a C string.
//------------------------------------------------------------------------------
LPCSTR loadTextResource(WORD resourceId)
{
    HRSRC resInfo = FindResource(NULL, MAKEINTRESOURCE(resourceId), _T("TEXT"));
    if (resInfo != NULL)
    {
        HGLOBAL resData = LoadResource(NULL, resInfo);
        if (resData != NULL)
        {
            // LockResource already yields NULL on failure, which is our failure value too
            return static_cast<LPCSTR>(LockResource(resData));
        }
    }
    return NULL;
}

//------------------------------------------------------------------------------
// Return string record field
//
// Reads field `col` of an MSI record as a string. The buffer starts with room
// for the terminator only and is grown whenever MsiRecordGetString reports
// ERROR_MORE_DATA (it then stores the required length, excluding the
// terminator, in the size argument). Any other failure is routed through OK().
//
// The buffer is a function-local static that is reused between calls, so this
// function is not reentrant.
//------------------------------------------------------------------------------
tstring recordGetString(MSIHANDLE record, UINT col)
{
    static tcharvector buffer(1);
    UINT status;

    do
    {
        DWORD capacity = static_cast<DWORD>(buffer.size());
        status = MsiRecordGetString(record, col, &buffer[0], &capacity);
        if (status == ERROR_MORE_DATA)
            buffer.resize(capacity + 1);   // room for the string plus its terminator
    }
    while (status == ERROR_MORE_DATA);

    OK(status);
    return tstring(&buffer[0]);
}