/*
*  Copyright (C) 2005 CyberLink Corp.
*
*  This program is free software; you can redistribute it and/or modify
*  it under the terms of the GNU General Public License as published by
*  the Free Software Foundation; either version 2 of the License, or
*  (at your option) any later version.
*
*  This program is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU General Public License for more details.
*
*  You should have received a copy of the GNU General Public License
*  along with this program; if not, write to the Free Software
*  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
                 
#include <stdio.h>
#include <ctype.h>
#include <windows.h>
#include <winioctl.h>
#include "zlib/zlib.h"

/* Interface */

// Patches the boot files found under drive+path and writes the boot code to
// the MBR of disk boot_disk_no. Returns 0 or an InstallMBRErrorCode.
int InstallMBR(char* drive, char* path, int boot_disk_no);
// Compares the MBR of disk boot_disk_no with the installed stage1 copy.
// Returns 0 when they agree, otherwise a CheckMBRErrorCode.
int CheckMBR(int boot_disk_no);
// Writes the saved original boot code back to the MBR of disk boot_disk_no.
// Returns 0 or a RestoreMBRErrorCode.
int RestoreMBR(int boot_disk_no);

/* Non-zero results of InstallMBR(); 0 means success.
   The numeric value is shown to the user by main() in the failure dialog. */
enum InstallMBRErrorCode {
    INVALID_PATH=1,           // the input path has invalid characters
    COLLECT_FAILED,           // failed to collect blocklist
    PATCH_STAGE1_FAILED,      // failed to patch stage1
    BACKUP_FAILED,            // failed to backup MBR
    UPDATE_FAILED,            // failed to update MBR
    GET_DISK_PARTITION_FAILED,// failed to get disk/partition number
    GENERATE_MENU_FAILED,     // failed to generate menu.lst
    PATCH_STAGE2_FAILED,      // failed to patch stage2
    PATCH_INITRD_FAILED       // failed to patch initrd
};

/* Non-zero results of CheckMBR(); 0 means the MBR still matches stage1. */
enum CheckMBRErrorCode {
    READ_FAILED=1,   // the MBR, the registry path or the stage1 copy could not be read
    MBR_MODIFIED     // the first 440 bytes of the MBR differ from the stage1 copy
};    

/* Non-zero results of RestoreMBR(); 0 means success. */
enum RestoreMBRErrorCode {
    GET_PATH_FAILED=1,      // installation path could not be read from the registry
    READ_ORIGINAL_FAILED,   // bootsect.orig could not be loaded
    READ_MBR_FAILED,        // current MBR could not be read from the disk
    WRITE_MBR_FAILED        // restored MBR could not be written to the disk
};

// internal
// Helpers defined further down in this file. The definitions omit `static`,
// but these earlier declarations already give the functions internal linkage.

// path validation and blocklist collection
static int valid_char(char c);
static char *collect_blocklist(char *device, char* image2);
// whole-buffer file I/O; both return 0 on success and -1 on failure
static int save_data(unsigned char *buf, size_t size, const char *pathname);
static int load_data(unsigned char *buf, size_t size, const char *pathname);
// raw access to the first 512 bytes of \\.\PhysicalDrive<drive_no>
static int read_mbr(int drive_no, unsigned char buf[512]);
static int write_mbr(int drive_no, unsigned char buf[512]);
// maps a drive such as "D:" to a physical disk index and a partition index
static int get_disk_part_no(const char* drive, int* disk_no, int* part_no);
// generators for menu.lst and the patched initrd
static int generate_menu(int boot_disk_no, int disk_no, int part_no, const char* path, const char* pathname);
static int patch_initrd(int disk_no, int part_no, const char* pathname_in, const char* pathname_out, const char* path_image);
// file attribute helpers
static void make_file_writable(const char* pathname);
static void make_file_protected(const char* pathname);
// installation path stored under HKLM\SOFTWARE\CyberLink\InstantOn, value "Path"
static int read_registry_path(char *path, int len);
static int write_registry_path(const char* path);
// copies the first 440 bytes of src over dst
static void update_mbr(char dst[512], const char src[512]);
// remove NTFS compression / encryption from the boot files
static BOOL VerifyFile(LPCTSTR szFile);
static void NormalizeFiles();

// external
// Implemented in other translation units; only their use in this file is
// described here.
// Both blocklist functions are tried in turn by collect_blocklist(); a NULL
// result is treated as failure.
char* ntfs_blocklist(char* device, char* path);
char* fat_blocklist(char* device, char* path);
// Patchers for the stage1 / stage2 files; a non-zero result is treated as
// failure by InstallMBR().
int stage1(char* filename, LONG sector, BYTE lba, BYTE drive);
int stage2(char* filename, char* blocklist, char* menu);
// NOTE(review): not referenced anywhere in the visible part of this file.
extern int win32_fd;
  
/* Driver */
/*
    Command-line entry point. Exactly one argument is expected:
        /r    restore the original MBR (only if the current MBR is still ours)
        /c    check the MBR; if it was modified, offer to reinstall using the
              installation path stored in the registry
        path  install from the given installation directory
    Returns 0 on success (or when there is nothing to do) and -1 on failure.
    Results are reported to the user through message boxes.
*/
int main(int argc, char **argv)
{
    if(argc != 2) {
        printf("\
Usage:\n\
    ion_install.exe path  Install MBR from installation directory path\n\
    ion_install.exe /c    Check whether MBR has been modified\n\
    ion_install.exe /r    Restore original MBR\n\
");
        return -1;
    }

    const int BOOT_DRIVE = 0;
    char src_path[MAX_PATH];
    if(strcmp(argv[1],"/r")==0)
    {
        // make sure the mbr is actually written by us
        if(CheckMBR(BOOT_DRIVE)) return 0;
        if(RestoreMBR(BOOT_DRIVE)) goto restore_failed;
        MessageBox(NULL, "Restore original MBR successfully",
                         "CyberLink PCM InstantOn", MB_OK);

        return 0;
    }    
    else if(strcmp(argv[1],"/c")==0)
    {
        NormalizeFiles();
        if(CheckMBR(BOOT_DRIVE) != MBR_MODIFIED) goto out;
        int ans = MessageBox(NULL, 
                  "Your MBR has been modified, reactivate PCM?",
                  "CyberLink PCM InstantOn",
                  MB_YESNOCANCEL);
        if(ans!=IDYES) goto out;
        if(read_registry_path(src_path, MAX_PATH)) goto path_failed;
    }
    else
    {
        // argv[1] is caller-controlled: refuse anything that does not fit
        // into src_path instead of overflowing the stack buffer
        if(strlen(argv[1]) >= sizeof(src_path)) goto path_failed;
        strcpy(src_path, argv[1]);
    }

// install:

    // split the absolute path into the drive ("D:") and the rest
    char full[MAX_PATH];    
    if(_fullpath(full, src_path, MAX_PATH)==NULL) goto path_failed;
    char drive[_MAX_DRIVE];
    char dir[_MAX_DIR];
    char fname[_MAX_FNAME];
    char ext[_MAX_EXT];
    _splitpath(full, drive, dir, fname, ext);
    char path[MAX_PATH];
    sprintf(path, "%s%s%s", dir, fname, ext);
    // InstallMBR() builds its paths with forward slashes
    char *p;
    for(p = path; *p; p++)
        if(*p=='\\') *p='/';
    int err = InstallMBR(drive, path, BOOT_DRIVE);
    if(err)
    {
        char msg[80];
        sprintf(msg, "Install MBR failed -- error code %d", err);
        MessageBox(NULL, msg, "CyberLink PCM InstantOn", MB_OK|MB_ICONEXCLAMATION);
        return -1;
    }
    else
    {
        MessageBox(NULL, "Install MBR successfully", 
                         "CyberLink PCM InstantOn", MB_OK|MB_ICONINFORMATION);
        return 0;
    }    
out:
    return 0;
path_failed:        
    MessageBox(NULL, "Failed to get installation path", 
                     "CyberLink PCM InstantOn", MB_OK);
    return -1;
restore_failed:
    MessageBox(NULL, "Failed to restore original MBR", 
                     "CyberLink PCM InstantOn", MB_OK);
    return -1;    
}

/*
    Install MBR on the specified boot disk (boot_disk_no),
    given the necessary files in (drive, path)
    e.g.
        drive = "D:"
        path = "/boot"
        boot_disk_no = 0
*/
int InstallMBR(char* drive, char* path_orig, int boot_disk_no)
{
    char path[MAX_PATH];
    GetShortPathName(path_orig, path, MAX_PATH);

    /*const char *p = path;
    while(*p) {
        if(!valid_char(*p++)) 
                return INVALID_PATH;
    }*/    

    // (1) collect the blocklist for stage2
    // The filename "safeboot.fs" avoids the file being defragmented by windows
    char pathname[MAX_PATH];
    sprintf(pathname, "%s/stage2/safeboot.fs", path);
    char *blocklist = collect_blocklist(drive, pathname);
    if(!blocklist){
        OutputDebugString("[Bazooka] Failed on collect_blocklist();\n");
        return COLLECT_FAILED;
    }
        

    int sector;
    if (sscanf(blocklist, "%d", &sector)==0) {
        OutputDebugString("[Bazooka] Failed on sscanf();\n");
        return COLLECT_FAILED;
    }


    // (2) get disk and partition number
    int disk_no, part_no;
    if(get_disk_part_no(drive, &disk_no, &part_no)) return GET_DISK_PARTITION_FAILED;

    // (3) patch stage1
    sprintf(pathname, "%s%s/stage1", drive, path);
    make_file_writable(pathname);
    const int force_lba = 0;
    if (stage1(pathname, sector, force_lba, 0x80+disk_no))
        return PATCH_STAGE1_FAILED;
        
    // (4) save path into registry
    sprintf(pathname, "%s%s", drive, path);
    write_registry_path(pathname);

    // (5) read MBR
    char mbr[512];
    if(read_mbr(boot_disk_no, mbr)) return BACKUP_FAILED;
    // backup only if it has not been saved or it has changed since we saved it
    if(CheckMBR(boot_disk_no))
    {        
        sprintf(pathname, "%s%s/bootsect.orig", drive, path);
        if(save_data(mbr, 512, pathname)) return BACKUP_FAILED;
    }
    
    // (6) update MBR and write back to stage1
    char stage1_orig[512];
    sprintf(pathname, "%s%s/stage1", drive, path);
    if(load_data(stage1_orig, 512, pathname)) return UPDATE_FAILED;
    update_mbr(mbr, stage1_orig);
    if(save_data(mbr, 512, pathname)) return UPDATE_FAILED;
    if(write_mbr(boot_disk_no, mbr) < 0) return UPDATE_FAILED;
    
    // (7) generate menu.lst
    sprintf(pathname, "%s%s/menu.lst", drive, path);
    if(generate_menu(boot_disk_no, disk_no, part_no, path, pathname))
        return GENERATE_MENU_FAILED;

    // (8) patch stage2
    char menu[MAX_PATH];
    sprintf(pathname, "%s%s/stage2/safeboot.fs", drive, path);
    sprintf(menu, "(hd%d,%d)%s/menu.lst", disk_no, part_no, path);
    make_file_writable(pathname);
    if(stage2(pathname, blocklist, menu))
        return PATCH_STAGE2_FAILED;

    // (9) patch and gzip initrd
    sprintf(pathname, "%s%s/wrd.gz", drive, path);
    char pathname2[MAX_PATH];
    sprintf(pathname2, "%s%s/initrd.gz", drive, path);
    char path_image[MAX_PATH];
    sprintf(path_image, "%s/image", path);
    if(patch_initrd(disk_no, part_no, pathname, pathname2, path_image))
        return PATCH_INITRD_FAILED;

    // (10) hide all those files
    NormalizeFiles();
    sprintf(pathname, "%s%s/stage1",drive,path);             make_file_protected(pathname);
    sprintf(pathname, "%s%s/stage2/safeboot.fs",drive,path); make_file_protected(pathname);
    sprintf(pathname, "%s%s/bootsect.orig",drive,path);      make_file_protected(pathname);
    sprintf(pathname, "%s%s/menu.lst",drive,path);           make_file_protected(pathname);

    return 0;
}    

/*
    Uncompress and decrypt the files which will be used outside Windows.
    The installation directory is taken from the registry; nothing happens
    when it cannot be read. Per-file failures are ignored.
*/
void NormalizeFiles()
{
    // names relative to the installation directory, in processing order
    static const char *const targets[] = {
        "stage2/safeboot.fs",
        "bootsect.orig",
        "menu.lst",
        "bzImage",
        "initrd.gz",
        "image"
    };
    char base[MAX_PATH];
    char target[MAX_PATH];
    size_t i;

    if(read_registry_path(base, MAX_PATH) != 0)
        return;

    for(i = 0; i < sizeof(targets)/sizeof(targets[0]); i++)
    {
        sprintf(target, "%s/%s", base, targets[i]);
        VerifyFile(target);
    }
}    

/*
    Check whether the current MBR on disk is still the same as the one we
    installed.
    Returns 0 when the first 440 bytes equal those of the installed stage1
    file, MBR_MODIFIED when they differ, and READ_FAILED when the MBR, the
    registry path or the stage1 file cannot be read.
*/
int CheckMBR(int boot_disk_no)
{
    char on_disk[512];
    char installed[512];
    char install_dir[MAX_PATH];
    char stage1_file[MAX_PATH];

    if(read_mbr(boot_disk_no, on_disk) != 0)
        return READ_FAILED;

    // the installation directory was stored in the registry by InstallMBR
    if(read_registry_path(install_dir, MAX_PATH) != 0)
        return READ_FAILED;

    sprintf(stage1_file, "%s/stage1", install_dir);
    if(load_data(installed, 512, stage1_file) != 0)
        return READ_FAILED;

    // bytes from offset 440 on are the variable part and are not compared
    return memcmp(on_disk, installed, 440) == 0 ? 0 : MBR_MODIFIED;
}

/*
    Restore MBR: put the boot code saved in bootsect.orig back into the
    MBR of disk boot_disk_no, keeping the bytes from offset 440 on as they
    currently are on disk.
    Returns 0 on success, otherwise one of RestoreMBRErrorCode.
*/
int RestoreMBR(int boot_disk_no)
{
    char install_dir[MAX_PATH];
    char backup_file[MAX_PATH];
    char saved[512];
    char current[512];

    if(read_registry_path(install_dir, MAX_PATH) != 0)
        return GET_PATH_FAILED;

    sprintf(backup_file, "%s/bootsect.orig", install_dir);
    if(load_data(saved, 512, backup_file) != 0)
        return READ_ORIGINAL_FAILED;

    if(read_mbr(boot_disk_no, current) != 0)
        return READ_MBR_FAILED;

    // merge the saved boot code into the sector just read, then write it
    update_mbr(current, saved);
    if(write_mbr(boot_disk_no, current) != 0)
        return WRITE_MBR_FAILED;

    return 0;
}

/*
    Tell whether c may appear in an installation path.
    Accepted: '/', ASCII letters and digits (as classified by isalnum in
    the current locale), '-', '.', '_' and '~'.
    Returns 1 for an accepted character and 0 otherwise.
*/
int valid_char(char c)
{
    switch(c)
    {
    case '/': case '-': case '.': case '_': case '~':
        return 1;
    default:
        // <ctype.h> functions are only defined for unsigned char values
        // (and EOF); a plain char may be negative, so convert first
        return isalnum((unsigned char)c) ? 1 : 0;
    }    
}

/*
    Obtain the blocklist of file image2 on device.
    ntfs_blocklist() is tried first and fat_blocklist() second.
    Returns the string produced by whichever succeeded, or NULL when both
    failed. Progress is reported through OutputDebugString().
*/
char *collect_blocklist(char *device, char* image2)
{
    // device and image2 come from the caller and may be longer than the
    // trace buffer, so the message is bounded (and truncated if needed)
    char szBuf[256];
    snprintf(szBuf, sizeof(szBuf), "[Bazooka] collect_blocklist(%s, %s)\n", device, image2);
    szBuf[sizeof(szBuf)-1] = '\0';
    OutputDebugString(szBuf);
    
    char* blocklist = NULL;
    OutputDebugString("[Bazooka] ntfs_blocklist()\n");
    blocklist = ntfs_blocklist(device, image2);
    if(!blocklist) {
        OutputDebugString("[Bazooka] fat_blocklist()\n");
        blocklist = fat_blocklist(device, image2);
    }
    return blocklist;
}

/*
    Write size bytes from buf to the file pathname, replacing its contents.
    The file attributes are reset first so that a read-only or hidden file
    from a previous installation can be overwritten.
    Returns 0 on success and -1 when the file cannot be opened, written
    or flushed.
*/
int save_data(unsigned char *buf, size_t size, const char *pathname)
{
    make_file_writable(pathname);
    FILE *fp = fopen(pathname, "wb");
    if(fp==NULL) return -1;

    int err = 0;
    if(fwrite(buf, size, 1, fp)!=1)
        err = -1;
    // buffered data reaches the disk in fclose(), so its result matters too
    if(fclose(fp)!=0)
        err = -1;
    return err;
}

/*
    Read exactly size bytes from the start of file pathname into buf.
    Returns 0 on success and -1 when the file cannot be opened or holds
    fewer than size bytes.
*/
int load_data(unsigned char *buf, size_t size, const char *pathname)
{
    int status = -1;
    FILE *in = fopen(pathname, "rb");

    if(in != NULL)
    {
        // one item of `size` bytes: anything shorter counts as failure
        if(fread(buf, size, 1, in) == 1)
            status = 0;
        fclose(in);
    }
    return status;
}

int read_mbr(int drive_no, unsigned char buf[512])
{
    char filename[MAX_PATH];

    sprintf(filename, "\\\\.\\PhysicalDrive%d", drive_no);
    HANDLE handle = CreateFile(filename,
			       GENERIC_READ,
			       FILE_SHARE_READ | FILE_SHARE_WRITE,
			       NULL,
			       OPEN_EXISTING,
			       FILE_ATTRIBUTE_SYSTEM,
			       NULL);
    if(handle == INVALID_HANDLE_VALUE)
        return -1;
    
    DWORD bytes_read;
    ReadFile(handle, buf, 512, &bytes_read, NULL);
    CloseHandle(handle);
    if(bytes_read != 512)
        return -1;
    return 0;
}
   
int write_mbr(int drive_no, unsigned char buf[512])
{
    char filename[MAX_PATH];

    sprintf(filename, "\\\\.\\PhysicalDrive%d", drive_no);
    HANDLE handle = CreateFile(filename,
			       GENERIC_WRITE,
			       FILE_SHARE_READ | FILE_SHARE_WRITE,
			       NULL,
			       OPEN_EXISTING,
			       FILE_ATTRIBUTE_SYSTEM,
			       NULL);
    if(handle == INVALID_HANDLE_VALUE)
        return -1;
    
    DWORD bytes_written;
    WriteFile(handle, buf, 512, &bytes_written, NULL);
    FlushFileBuffers(handle);
    CloseHandle(handle);
    if(bytes_written != 512)
        return -1;
    return 0;
}

/*
    Compare two partition descriptions field by field.
    Returns 1 when offset, length, hidden sectors, number, type and the
    boot / recognized / rewrite flags are all equal, otherwise 0.
*/
int match(PARTITION_INFORMATION* p1, PARTITION_INFORMATION *p2)
{
    return p1->StartingOffset.QuadPart  == p2->StartingOffset.QuadPart
        && p1->PartitionLength.QuadPart == p2->PartitionLength.QuadPart
        && p1->HiddenSectors            == p2->HiddenSectors
        && p1->PartitionNumber          == p2->PartitionNumber
        && p1->PartitionType            == p2->PartitionType
        && p1->BootIndicator            == p2->BootIndicator
        && p1->RecognizedPartition      == p2->RecognizedPartition
        && p1->RewritePartition         == p2->RewritePartition;
}    

/*
    Find out on which physical disk, and as which partition, the volume
    `drive` (e.g. "D:") lives.

    The partition information and drive layout are queried through the
    volume itself. PhysicalDrive0..15 are then scanned for a disk whose
    layout (partition count, signature and every entry) is identical, and
    the entry equal to the volume's own partition information is located.

    On success *disk_no receives the PhysicalDrive index and *part_no the
    partition index computed below, and 0 is returned. -1 is returned when
    a query fails or no matching disk / partition is found.
*/
int get_disk_part_no(const char* drive, int* disk_no, int* part_no)
{
  char filename[30];
  // "\\.\" (4 characters) + drive + terminating NUL must fit
  if(strlen(drive) + 5 > sizeof(filename))
      return -1;
  sprintf(filename,"\\\\.\\%s", drive);

  HANDLE handle = CreateFile(filename,
                             GENERIC_READ,
                             FILE_SHARE_READ|FILE_SHARE_WRITE,
                             NULL,
                             OPEN_EXISTING,
                             0,
                             NULL);
  if (handle == INVALID_HANDLE_VALUE)
      return -1;

  PARTITION_INFORMATION part_info;
  DWORD bytes_returned;
  BOOL rv;
  rv = DeviceIoControl(
         handle,
         IOCTL_DISK_GET_PARTITION_INFO,
         NULL, 0,
         &part_info, sizeof(part_info),
         &bytes_returned,
         NULL);
  if(rv == 0)
  {
    CloseHandle(handle);
    return -1;
  }

  char buffer[4096];
  DRIVE_LAYOUT_INFORMATION* layout = (DRIVE_LAYOUT_INFORMATION*)buffer;
  rv = DeviceIoControl(
         handle,
         IOCTL_DISK_GET_DRIVE_LAYOUT,
         NULL, 0,
         &buffer, sizeof(buffer),
         &bytes_returned,
         NULL);

  CloseHandle(handle);
  if(rv==0) return -1;

  int i;
  for(i=0;i<16;i++)
  {
    sprintf(filename,"\\\\.\\PhysicalDrive%d", i);
    HANDLE disk = CreateFile(filename,
                             GENERIC_READ,
                             FILE_SHARE_READ|FILE_SHARE_WRITE,
                             NULL,
                             OPEN_EXISTING,
                             0,
                             NULL);
    if(disk == INVALID_HANDLE_VALUE)
        continue;

    char buffer1[4096];
    DRIVE_LAYOUT_INFORMATION *layout1 = (DRIVE_LAYOUT_INFORMATION*)buffer1;
    rv = DeviceIoControl(disk,
                         IOCTL_DISK_GET_DRIVE_LAYOUT,
                         NULL, 0,
                         &buffer1, sizeof(buffer1),
                         &bytes_returned,
                         NULL);
    CloseHandle(disk);
    // buffer1 holds no valid layout when the query failed: skip this disk
    if(rv == 0)
        continue;

    if(layout->PartitionCount != layout1->PartitionCount ||
       layout->Signature != layout1->Signature)
        continue;

    // every entry must agree before the disk is accepted as the same one
    DWORD j;
    for(j=0;j<layout->PartitionCount;j++)
        if(!match(&layout->PartitionEntry[j],&layout1->PartitionEntry[j])) break;
    if(j != layout->PartitionCount)
        continue;

    // The first four entries are numbered by their position. Later
    // entries advance the number only when their type is none of
    // 0, 0x05, 0x85, 0x0f (presumably unused slots and extended-partition
    // containers -- TODO confirm).
    int cur_no=0;
    for(j=0;j<layout->PartitionCount;j++)
    {
        if(j<4) cur_no=(int)j;
        else {
            int type = layout->PartitionEntry[j].PartitionType;
            if(type != 0 && type != 5 && type !=0x85 && type !=0x0f)
                cur_no++;
        }
        if(match(&layout->PartitionEntry[j],&part_info))
        {
            *disk_no = i;
            *part_no = cur_no;
            return 0;
        }    
    }    
  }
  return -1;
}

/*
    Write the boot menu file pathname.

    The menu has two entries: "Original", which chainloads the saved
    bootsect.orig with (hd<boot_disk_no>,0) as root, and "PCM Linux",
    which boots bzImage with initrd.gz. All three files are addressed as
    (hd<disk_no>,<part_no>)<path>/<name>, e.g. "(hd0,0)/boot/bzImage".

    Returns 0 on success and -1 when the file cannot be opened, written
    or flushed.
*/
int generate_menu(int boot_disk_no, int disk_no, int part_no, const char* path, const char* pathname)
{    
    make_file_writable(pathname);
    FILE *fp = fopen(pathname,"wb");
    if(fp == NULL) return -1;

    // The "(hdD,P)path" prefix is formatted directly into the file, so no
    // fixed-size intermediate buffer can overflow on a long path.
    const char* contents = "\
timeout 0\n\
default 0\n\
hiddenmenu\n\
\n\
title Original\n\
rootnoverify (hd%d,0)\n\
chainloader (hd%d,%d)%s/bootsect.orig\n\
\n\
title PCM Linux\n\
kernel (hd%d,%d)%s/bzImage root=/dev/loop0 rw\n\
initrd (hd%d,%d)%s/initrd.gz";

    int err = 0;
    if(fprintf(fp, contents, boot_disk_no,
               disk_no, part_no, path,
               disk_no, part_no, path,
               disk_no, part_no, path) < 0)
        err = -1;
    // buffered data reaches the disk in fclose(), so its result matters too
    if(fclose(fp) != 0)
        err = -1;
    return err;
}

int patch_initrd(int disk_no, int part_no, const char* pathname_in, const char* pathname_out, const char *path_image)
{
    gzFile *fp = gzopen(pathname_in,"rb");
    if(fp == NULL) return -1;
    make_file_writable(pathname_out);
    gzFile fp2 = gzopen(pathname_out,"wb9");
    if(fp2 == NULL) { gzclose(fp); return -1; }

    const char* mark = "## CL-INITRD-TEMPLATE ##\n";
    const char* format = "\
#!/bin/sh\n\
mount -t proc proc proc\n\
mount /dev/hd%c%d /mnt\n\
losetup /dev/loop0 /mnt%s\n\
";
    char contents[512];
    sprintf(contents, format, 'a'+disk_no, part_no+1, path_image);
    int err = 0;
    char buf[512];
    while(1)
    {
        size_t got = gzread(fp, buf, 512);
        if(got == 0) break;
        if(got == 512 && memcmp(mark, buf, strlen(mark))==0)
                memcpy(buf, contents, strlen(contents));
        size_t sent = gzwrite(fp2, buf, got);
        if(sent != got) {
                err = -1;
                goto out;
        }    
    }
out:
    gzclose(fp);
    gzclose(fp2);
    return err;
}    
    
void make_file_writable(const char* pathname)
{
    SetFileAttributes(pathname, FILE_ATTRIBUTE_NORMAL);
}

void make_file_protected(const char* pathname)
{
    SetFileAttributes(pathname, FILE_ATTRIBUTE_HIDDEN | FILE_ATTRIBUTE_SYSTEM | FILE_ATTRIBUTE_READONLY);
}

// Registry location of the installation path: the value "Path" under this
// key is written by write_registry_path() and read by read_registry_path().
HKEY hkey = HKEY_LOCAL_MACHINE;
const char *subkey = "SOFTWARE\\CyberLink\\InstantOn";

int read_registry_path(char *path, int len)
{
    HKEY key;
    if(RegOpenKey(hkey, subkey, &key))
        return -1;

    DWORD size = len;
    int err = 0;
    if(RegQueryValueEx(key, "Path", NULL, NULL, (BYTE*)path, &size))
        err = -1;

    RegCloseKey(key);
    return err;
}

/*
    Store path as the string value "Path" under hkey\subkey, creating the
    key when it does not exist yet.
    Returns 0 on success and -1 when the key cannot be created / opened or
    the value cannot be set.
*/
int write_registry_path(const char* path)
{
    HKEY key;
    int status = 0;

    if(RegCreateKey(hkey, subkey, &key) != 0)
        return -1;

    // the stored size includes the terminating NUL
    if(RegSetValueEx(key, "Path", 0, REG_SZ, (BYTE*)path, strlen(path)+1) != 0)
        status = -1;

    RegCloseKey(key);
    return status;
}    

/*
    Replace the first 440 bytes of the sector dst with those of src.
    Bytes 440..511 of dst are left untouched.
*/
void update_mbr(char dst[512], const char src[512])
{
    enum { COPIED_BYTES = 440 };

    memcpy(dst, src, COPIED_BYTES);
}

/*
    Switch off file-system compression for szFile by sending
    FSCTL_SET_COMPRESSION with COMPRESSION_FORMAT_NONE.
    The file is opened for exclusive read/write access.
    Returns TRUE when the control call succeeded, FALSE when the file
    cannot be opened or the call failed.
*/
static BOOL UnCompressFile(LPCTSTR szFile)
{
        HANDLE hTarget = CreateFile(szFile, GENERIC_READ|GENERIC_WRITE, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
        if (INVALID_HANDLE_VALUE == hTarget)
                return FALSE;

        USHORT uFormat = COMPRESSION_FORMAT_NONE;
        DWORD  dwBytes = 0;
        BOOL   bCleared = FALSE;
        if (DeviceIoControl(hTarget, FSCTL_SET_COMPRESSION, (LPVOID) &uFormat, sizeof(USHORT), NULL, 0, &dwBytes, NULL))
                bCleared = TRUE;

        CloseHandle(hTarget);
        return bCleared;
}

// Private copy of the "encrypted" attribute bit, tested by VerifyFile().
#define _FILE_ATTRIBUTE_ENCRYPTED       0x4000          // VC6's FILE_ATTRIBUTE_ENCRYPTED is wrong

// Declared by hand and mapped to the ANSI variant.
// NOTE(review): presumably because the headers in use do not declare
// DecryptFile -- confirm against the build environment.
WINBASEAPI BOOL WINAPI DecryptFileA(LPCTSTR, DWORD);
#define DecryptFile DecryptFileA

/*
    Make sure szFile is stored neither compressed nor encrypted.
    The attributes are read once; compression is removed first, then
    encryption.
    Returns TRUE when the file ends up in the required state, FALSE when
    its attributes cannot be read or one of the conversions fails.
*/
BOOL VerifyFile(LPCTSTR szFile)
{
        const DWORD dwAttr = GetFileAttributes(szFile);

        // (DWORD)-1 is the failure value of GetFileAttributes()
        if ((DWORD)-1 == dwAttr)
                return FALSE;

        if ((dwAttr & FILE_ATTRIBUTE_COMPRESSED) && !UnCompressFile(szFile))
                return FALSE;

        if ((dwAttr & _FILE_ATTRIBUTE_ENCRYPTED) && !DecryptFile(szFile, 0))
                return FALSE;

        return TRUE;
}

