#!/usr/bin/env pike
/* -*- Pike -*- */

//#define DEBUG

// Not every Pike build provides Stdio.PROP_IPC. When it is missing,
// NO_IPC is defined and the pipes further down are created without
// that property flag.
#if !constant(Stdio.PROP_IPC)
#define NO_IPC
#endif

// Bugfix for some older versions of Pike..
//
// Wrapper around predef::combine_path(). The components in rest are
// folded onto s from left to right. A component that matches
// "<letters>:<anything>" (e.g. "c:/foo") is not appended; it replaces
// everything accumulated so far.
string combine_path(string s, string ... rest)
{
  foreach(rest, string part)
  {
    if(sscanf(part,"%*[a-zA-Z]:%*s")==2)
      s=part;                              // Restart from this component.
    else
      s=predef::combine_path(s,part);
  }
  return s;
}

// Bugfix for some older versions of Pike..
//
// Copy the file from to the file to. Stdio.cp() is tried first; if it
// reports failure the file is copied by hand in chunks of BLOCK bytes.
// Returns 1 on success and 0 on failure (a diagnostic is written to
// stderr in the failure case).
#define BLOCK 65536
int cp(string from, string to)
{
  if(Stdio.cp(from,to))
    return 1;

  werror("Backup cp function in effect.\n");

  object src=Stdio.File();
  if(!src->open(from,"r"))
  {
    werror(sprintf("Open %s for reading failed.\n",from));
    return 0;
  }

  object dst=Stdio.File();
  if(!dst->open(to,"wct"))
  {
    werror(sprintf("Open %s for writing failed.\n",to));
    return 0;
  }

  string chunk;
  do
  {
    chunk=src->read(BLOCK);
    if(!chunk)
    {
      werror("Read failed.\n");
      return 0;
    }
    if(dst->write(chunk)!=sizeof(chunk))
    {
      werror("Write failed.\n");
      return 0;
    }
    // A chunk shorter than BLOCK ends the copy.
  }while(sizeof(chunk) == BLOCK);

  return 1;
}

// Wait for proc to finish and then, if io is still a live object,
// close it in both directions and destruct it.
void monitor(object(Stdio.File) io, object proc)
{
  proc->wait();

  if(!io)
    return;

  io->close("rw");
  io->close();
  destruct(io);
}

// Join two ';'-separated lists into one, dropping empty entries.
// Either argument may be 0, in which case it is treated as empty.
string opt_path(string p1, string p2)
{
  string joined=(p1||"") + ";" + (p2||"");
  array(string) entries=joined / ";";
  entries-=({""});
  return entries * ";";
}

// Taken by myproxy() around each framed write to the client
// connection, so frames written by different proxy threads do not
// interleave.
Thread.Mutex outlock=Thread.Mutex();

// Pump data from pi to io until pi yields nothing more (read returns
// zero or an empty string). Every chunk is sent as one frame: a
// single byte holding channel, a three byte length, and the data.
// The frame is written while holding outlock. p is only used for
// the DEBUG trace.
void myproxy(object pi, object io, int channel, Process.Process p)
{
  string chunk;

  // Don't read more than 2<<24 :)
  while((chunk=pi->read(1000,1)) && sizeof(chunk))
  {
#if 0
    // Disabled since nonblocking doesn't work on windows. Probably
    // not needed anymore anyway.
    if (sizeof(chunk) == 1) {
      // Stderr has a tendency to get one character at a time...
      if (mixed err = catch {
	  pi->set_nonblocking_keep_callbacks();
	  string extra;
	  while ((extra = pi->read(1000 - sizeof(chunk), 1)) && sizeof(extra)) {
	    chunk += extra;
	  }
	} ||
	catch {
	  pi->set_blocking_keep_callbacks();
	})
	werror ("Error reading nonblocking from %O: %s",
		pi, describe_error (err));
    }
#endif
    object key=outlock->lock();
#ifdef DEBUG
    werror ("Process %O: Channel %d: %O\n", p->pid(), channel, chunk);
#endif
    string frame=sprintf("%c%3c%s",channel,sizeof(chunk),chunk);
    if (io->write(frame) != sizeof(chunk) + 4)
      werror ("Write on %O failed: %d\n", io, io->errno());
    destruct(key);
  }
}

// Serve one client connection.
//
// Request, as read from io:
//   a "%4c" integer giving the number of arguments, then for each
//   argument a "%4c" length followed by that many bytes.
//   Argument 0 is the working directory. It may be followed by any
//   number of KEY=VALUE arguments that adjust the environment, and
//   then the command with its own arguments.
//
// Environment adjustments:
//   KEY is matched against the existing environment case
//   insensitively when no exact match exists. A VALUE starting with
//   ';' is appended to the current value, one ending with ';' is
//   prepended to it (both via opt_path()); any other VALUE replaces
//   it.
//
// Reply, as written to io:
//   zero or more length-prefixed chunks of output, a "%4c" zero that
//   ends the output, and finally a "%4c" status code.
//
// The commands "mkdir", "copy", "getenv" and "sprshd-ping" are handled
// internally. Everything else is started with Process.create_process()
// and its stdout/stderr are forwarded by myproxy().
void handle_incoming_connection(object(Stdio.File) io)
{
  object p;
  mapping env=copy_value(getenv());
  sscanf(io->read(4),"%4c",int args);
  array(string) cmd=allocate(args);
  for(int e=0;e<args;e++)
  {
    sscanf(io->read(4),"%4c",int len);
    cmd[e]=io->read(len);
  }

  if(!sizeof(cmd))
  {
    // Truncated or empty request: there is not even a working
    // directory, so there is nothing that can be done or answered.
    werror("Malformed request: no arguments received.\n");
    destruct(io);
    return;
  }

  object pi=Stdio.File();
  object pe=Stdio.File();
#ifdef NO_IPC
  object p2=pi->pipe();
  object pe2=pe->pipe();
#else
  object p2=pi->pipe(Stdio.PROP_IPC);
  object pe2=pe->pipe(Stdio.PROP_IPC);
#endif
  string dir=cmd[0];
  cmd=cmd[1..];

  // Consume leading KEY=VALUE arguments. The sizeof() test stops the
  // loop when a request consists of nothing but assignments.
  while(sizeof(cmd) && sscanf(cmd[0],"%s=%s",string key, string val))
  {
    // Magic: find the spelling of the key that is already in use.
    if(!env[key])
    {
      if(env[lower_case(key)])
        key=lower_case(key);
      else if(env[upper_case(key)])
        key=upper_case(key);
      else
      {
        foreach(indices(env), string x)
        {
          if(lower_case(x) == lower_case(key))
          {
            key=x;
            break;
          }
        }
      }
    }
    // val may be empty ("KEY="), so check the size before indexing.
    if(sizeof(val) && val[0]==';')
    {
      env[key]=opt_path(env[key], val);
    }
    else if(sizeof(val) && val[-1]==';')
    {
      env[key]=opt_path(val, env[key]);
    }
    else
    {
      env[key]=val;
    }
#ifdef DEBUG
    werror("%s = %s\n",key,env[key]);
#endif
    cmd=cmd[1..];
  }

  if(!sizeof(cmd))
  {
    // Nothing left to run. Answer in the normal reply format, using
    // the same status code as a failed create_process().
    string x="No command given.\n";
    io->write(sprintf("%4c%s",strlen(x),x));
    io->write(sprintf("%4c",0));
    io->write(sprintf("%4c",69));
    destruct(io);
    return;
  }

  write("Doing %s in %s\n",cmd*" ",dir);

  switch(lower_case(cmd[0]))
  {
    case "mkdir":
    {
      // Directory to create, resolved relative to the request's dir.
      string target=combine_path(combine_path(getcwd(),dir),cmd[1]);
      int ret;
      // A file size of -2 means the directory already exists, which
      // is counted as success.
      if(Stdio.file_size(target)!=-2)
        ret=mkdir(target);
      else
        ret=1;
      if(!ret)
      {
        string x=sprintf("MKDIR %s failed, errno=%d\n",target,errno());
        io->write(sprintf("%4c%s",strlen(x),x));
      }
      io->write(sprintf("%4c",0));
      io->write(sprintf("%4c",!ret));
      break;
    }

    case "copy":
    {
      string from=combine_path(combine_path(getcwd(),dir),cmd[1]);
      string to=combine_path(combine_path(getcwd(),dir),cmd[2]);

      // Copying onto a directory means copying into it, keeping the
      // source's file name.
      if(mixed stat=file_stat(to))
      {
        if(stat[1]==-2)
        {
          to=combine_path(to,basename(cmd[1]));
        }
      }


      int ret=cp(from,to);
      if(!ret)
      {
        string x=sprintf("Errno is %d\n"
                         "CWD=%s\n"
                         "from=%s\n"
                         "to=%s\n"
                         "dir=%s (%s)\n",
                         errno(),
                         getcwd(),
                         from,
                         to,
                         dir, combine_path(getcwd(),dir));
        io->write(sprintf("%4c%s",strlen(x),x));
      }
      io->write(sprintf("%4c",0));
      io->write(sprintf("%4c",!ret));
      break;
    }
    case "getenv":
    {
      // Without an argument the whole (adjusted) environment is
      // listed, otherwise only the value of the named variable.
      string s;
      if(sizeof(cmd)<2)
      {
        s="";
        foreach(indices(env), string x)
          s+=sprintf("%s=%s\n",x,env[x]);
      }else{
        s=(env[cmd[1]] || "")+"\n";
      }
      io->write(sprintf("%4c%s",strlen(s),s));
      io->write(sprintf("%4c",0));
      io->write(sprintf("%4c",0));
      break;
    }

    case "sprshd-ping": {
      // This doesn't require that the paths are mapped.
      string s = "sprshd is alive.\n";
      io->write(sprintf("%4c%s", sizeof (s), s));
      io->write(sprintf("%4c",0));
      io->write(sprintf("%4c",0));
      break;
    }

    default:
#ifdef WINE
      void my_proxy(Stdio.File from, Stdio.File to)
        {
          while(string s=from->read(128,1))
            if(to->write(s)!=strlen(s))
              return;
          if(p) p->kill(9); /* DIE! */
        }

    {
      werror("Proxying.....\n");
      object p3=Stdio.File();
#ifdef NO_IPC
      object p4=p3->pipe();
#else
      object p4=p3->pipe(Stdio.PROP_IPC);
#endif
      thread_create(my_proxy,io,p4);
      io=p3;
    }
#endif

#if __VERSION__ >= 0.699999
   if(io->read_oob)
   {
#ifdef DEBUG
     write("Trapping OOB\n");
#endif
     // Out-of-band data from the client is treated as an interrupt.
     thread_create(lambda() {
       while(1)
       {
         string tmp=io->read_oob(1);
         if(!tmp || !sizeof(tmp)) return;
         werror("**Interrupt received, killing child.\n");
         // The interrupt can arrive before the child exists (or
         // after create_process failed); p is zero in that case.
         if(p) p->kill(9);
       }
     });
   }
#endif

      mixed err=catch {
        p=Process.create_process(cmd,
                                 ([
#ifndef WINE
                                   "stdin":io,
                                   "stdout":p2,
                                   "stderr":pe2,
#endif
                                   "cwd":dir,
                                   "env":env,
                                   ]));
      };
      // Drop this process's references to the child's pipe ends.
      destruct(p2);
      destruct(pe2);
      if(!err)
      {
#ifdef DEBUG
        werror ("Forked process: %O\n", p->pid());
#endif

#ifdef NO_IPC
        // NOTE(review): p2 has already been destructed above, so
        // monitor() only ends up waiting for the process here.
        thread_create(monitor,p2,p);
#endif

        // stderr is pumped in a separate thread and stdout in this
        // one. Clients that set __handles_stderr get the streams on
        // separate channels (1 and 2); otherwise both use channel 0.
        object proxythread;
        if(env->__handles_stderr)
        {
          proxythread=thread_create(myproxy,pe,io,2,p);
          myproxy(pi,io,1,p);
        }else{
          proxythread=thread_create(myproxy,pe,io,0,p);
          myproxy(pi,io,0,p);
        }
#ifdef DEBUG
        werror ("Process %O: Waiting for stderr on %O to close.\n",
                p->pid(), pe);
#endif
        proxythread->wait();

        // End of output marker.
        if (io->write(sprintf("%4c",0)) != 4)
          werror ("Write on %O failed: %d\n", io, io->errno());

#ifdef DEBUG
        werror ("Waiting for process %O to exit.\n", p->pid());
#endif
        int code;
        if (mixed werr = catch { code = p->wait(); }) {
          werror ("Wait on %O failed: %s", p, describe_error (werr));
          io->write(sprintf("%4c", -1));
        } else {
#ifdef DEBUG
          werror ("Process %O exited with code %d.\n", p->pid(), code);
#endif
          if (io->write(sprintf("%4c",code)) != 4)
            werror ("Write on %O failed: %d\n", io, io->errno());
        }
      }else{
        werror("create_process failed: %s", master()->describe_error (err));
        // A dir that is neither "X:..." nor a "\\\\" or "//" path is
        // most likely a path that only exists on the client side.
        if(!dir || sizeof(dir)<2 ||
           (dir[1] != ':' &&
            !has_prefix (dir, "\\\\") &&
            !has_prefix (dir, "//")))
          werror("\n*** %O is probably not mapped on the windows machine! ***\n\n", dir);
        io->write(sprintf("%4c",0));
        io->write(sprintf("%4c",69));
      }
  }
#ifdef WINE
  io->close("rw");
#else
  io->close("w");
#endif
  destruct(io);
}

// Accept loop. Every accepted connection whose peer address matches
// at least one of the glob patterns in hosts is handed to
// handle_incoming_connection() in a new thread; all others are
// dropped. Never returns.
void handle_connections(array(string) hosts)
{
  while(1)
  {
    object io=connection->accept();
    if(!io)
    {
      werror("Accept failed "+connection->errno()+"\n");
      continue;
    }

    // query_address() gives "<address> <port>", or zero if the peer
    // is already gone. That must not be allowed to throw here, since
    // an error in this loop would stop the server from accepting.
    string ip;
    string addr=io->query_address();
    if(!addr || sscanf(addr,"%s ",ip)!=1)
    {
      werror("Could not determine peer address, dropping connection.\n");
      destruct(io);
      continue;
    }

    int ok=0;
    foreach(hosts, string host) ok+=glob(host, ip);
    if(!ok)
    {
      werror("Connection from %s denied!!\n",ip);
      destruct(io);
      continue;
    }
    thread_create(handle_incoming_connection,io);
  }
}

// The listening port. It is bound in main() and accepted on in
// handle_connections().
object connection = Stdio.Port();

// Entry point: sprshd <port> <hosts to accept connections from>
//
// Binds the listening port and builds the list of accepted peers.
// Arguments that look like "<digits>.<digits>..." are used as they
// are (as glob patterns); all others are resolved with
// gethostbyname() and the resulting addresses are added. Then the
// accept loop is entered. In WINE mode the loop runs in its own
// thread and main returns -1 instead.
int main(int argc, array(string) argv)
{
#ifdef WINE
  werror("Running in WINE mode.\n");
#endif
  if(argc < 3)
  {
    werror("Usage: sprshd <port> <hosts to accept connections from>\n");
    sleep(10); //Make sure the terminal isn't closed at once.
    exit(1);
  }

  int port=(int)argv[1];
  if(!connection->bind(port))
  {
    werror("Failed to bind port: %s\n", strerror (connection->errno()));
    sleep(10);
    exit(1);
  }

  array(string) hosts=({});
  foreach(argv[2..], string spec)
  {
    if(sscanf(spec,"%*d.%*d")==2)
    {
      // Numeric address or pattern; no lookup needed.
      hosts+=({spec});
    }
    else
    {
      mixed info=gethostbyname(spec);
      if(!info)
      {
        werror("Gethostbyname("+spec+") failed.\n");
        exit(1);
      }
      hosts+=info[1];
    }
  }

  write("Ready ("+version()+").\n");

#ifdef WINE
  thread_create(handle_connections,hosts);
  werror("main returning...\n");
  return -1;
#else
  handle_connections(hosts);
  return 0;
#endif
}
