# ex:ts=4
#
# socketfarm
#   -- idle accept preforking daemon framework
#
# $LinuxKorea: socketfarm.py,v 1.4 2001/12/10 12:21:34 perky Exp $
#
# Copyright 2001 Hye-Shik Chang. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without 
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. Neither the name of author nor the names of its contributors
#    may be used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
# ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE. 
#
# code by Hye-Shik Chang <perky@linuxkorea.co.kr>
#

import sys

if sys.hexversion < 0x2010000:
	# Refuse to load on interpreters older than 2.1 (hexversion 0x020100..).
	raise ImportError("This module requires Python 2.1 or above")

import socket
import os, signal
import string, errno
import time, weakref
from select import select

# Stop codes.  SocketFarm.loop() keeps running while its stop code is
# SC_CONTINUE (0) and returns the code that ended it.
SC_CONTINUE, SC_SIGHUP, SC_SIGTERM, SC_SIGINT, SC_BOOM = range(5)

# Termination signals handled by the master, and the stop code each records.
TERMSIGS = {}
TERMSIGS[signal.SIGHUP] = SC_SIGHUP
TERMSIGS[signal.SIGTERM] = SC_SIGTERM
TERMSIGS[signal.SIGINT] = SC_SIGINT

# Shorthands for the socket types a Control subclass may declare.
SOCK_DGRAM = socket.SOCK_DGRAM
SOCK_STREAM = socket.SOCK_STREAM

class Control:
	""" SocketFarm Connection Controller BaseClass

	Subclass this to implement a service.  Stream services override run()
	(and optionally access()); datagram services set
	``socktype = SOCK_DGRAM`` and override process().  Every item of the
	``side`` mapping passed to __init__ becomes an instance attribute.
	"""

	socktype = SOCK_STREAM

	def __init__(self, sock, addr, side):
		self.sock = sock
		self.addr = addr
		# Expose the service's extra keyword arguments as attributes.
		for k, d in side.items():
			self.__dict__[k] = d

	def start(self):
		"""Run one session: ask access(), then either run() or drop the peer.

		Returns run()'s result, or None when access() rejected the peer
		(its socket is closed in that case).
		"""
		if self.access(self.addr):
			self.sock.close()
		else:
			return self.run()

	def access(self, addr):
		"""Return a true value to REJECT the peer at `addr` (override me).

		The default returns None, i.e. every peer is accepted.
		"""
		return None

	def run(self):
		"""Serve the accepted connection held in self.sock (stream sockets).

		Must be overridden.  By the original design, returning 1 means
		"do not run anymore".
		NOTE(review): Child.loop() discards start()'s result, so that
		request is currently not acted upon.
		"""
		# NotImplemented is a constant, not an exception: raising it only
		# produced a TypeError.  NotImplementedError is the intended signal.
		raise NotImplementedError

	def process(self, data, addr):
		"""Handle one datagram `data` received from `addr` (dgram sockets).

		Used instead of start()/access()/run() for SOCK_DGRAM services;
		self.sock is then the service's own socket.  Must be overridden.
		"""
		raise NotImplementedError


class Service:
	""" SocketFarm Service Description Class

	Creates the socket of one service and binds it to `address`: a tuple
	selects an AF_INET socket, anything else (presumably a filesystem path)
	an AF_UNIX one.  The socket type is taken from the Control class
	`sessctl`; stream sockets are put into listening state.  `side` plus any
	extra keyword arguments are merged into self.side, which is later
	handed to every Control instance.
	"""

	def __init__(self, address, sessctl=Control, side=None, **kwargs):
		# isinstance() also accepts tuple subclasses; type(()) keeps this
		# working on interpreters where `tuple` is not yet a type.
		if isinstance(address, type(())):
			family = socket.AF_INET
		else:
			family = socket.AF_UNIX

		self.address	= address
		self.sock		= socket.socket(family, sessctl.socktype)
		self.fileno		= self.sock.fileno()
		self.control	= sessctl
		self.socktype	= sessctl.socktype
		# Work on a copy so the caller's mapping is never mutated.
		if side is None:
			side = {}
		self.side = side.copy()
		self.side.update(kwargs)

		bound = 0
		try:
			# Best effort: SO_REUSEADDR lets the daemon rebind its address
			# right after a restart.
			try:
				self.sock.setsockopt(
					socket.SOL_SOCKET, socket.SO_REUSEADDR,
					self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1
				)
			except socket.error:
				sys.stdout.write("setsockopt error\n")

			self.sock.bind(address)

			if self.socktype == SOCK_STREAM:
				self.sock.listen(128)
			elif self.socktype == SOCK_DGRAM:
				# recvfrom() buffer size; a missing or zero value means 8192.
				self.recvbuff = self.side.get('recvbuff') or 8192
			bound = 1
		finally:
			# Don't leak the descriptor when setup fails (e.g. EADDRINUSE).
			if not bound:
				self.sock.close()

	def close(self):
		# Release this service's socket.
		self.sock.close()


class SocketFarm:
	""" SocketFarm, the multi-socket prefork daemon framework

	The master process owns the service sockets (see add_service()), keeps
	`procnum` forked children alive and forks replacements as they exit.
	loop() runs until a termination signal, a SIGCHLD "boom" or stop()
	ends it, and returns the stop code (SC_SIGHUP/SC_SIGTERM/SC_SIGINT for
	a signal, SC_BOOM for a burst of exiting children, or whatever was
	passed to stop()).
	"""

	# Upper bound, in seconds, of one idle wait of the master process.  The
	# wait must be bounded: a SIGCHLD delivered between testing the child
	# list and starting to wait would otherwise be lost, and the master
	# could block forever (e.g. after its last child has exited).
	idle_timeout = 1.0

	def __init__(self, procnum=4, limit=(-1, -1, 0), boom=(3.0, 10)):
		self.services		= {}		# fileno -> Service instance
		self.children		= []		# pids of live children
		self.maxprocess		= procnum	# number of children kept preforked
		self.limit			= limit		# per-child limits handed to Child:
										# (tcp count, udp count, cpu seconds)
		self.stopcode		= 0			# return value of loop(); nonzero stops it
		self.serial			= 0			# serial number of the next forked child
		if boom:
			self.boom_threshold = boom[0]	# max seconds between SIGCHLDs counted as a burst
			self.boom_crpoint	= boom[1]	# burst length above which the server stops
		else:
			self.boom_threshold, self.boom_crpoint = 0, 0
		self.boom_cnt		= 0			# SIGCHLDs counted in the current burst window
		self.boom_ltm		= 0			# time the current burst window started

		self.SIGCHLD_EVENT	= 0			# set when SIGCHLD arrives while its handler runs
		self.children_ignore = []		# pids reaped before fork_child()'s result was recorded

		self.oninit()

	def __del__(self):
		self.atexit()
		# Iterate over a snapshot: del_service() mutates self.services.
		for sid in list(self.services.keys()):
			self.del_service(sid)

	def log(self, *args):
		# XXX: Temporary hehe :)  (all args must be strings)
		sys.stdout.write(' '.join(args) + '\n')

	def add_service(self, ctl, addr, **kwargs):
		# Create and bind a Service; its fileno is the id used by del_service().
		s = Service(addr, ctl, kwargs)
		self.services[s.fileno] = s
		return s.fileno

	def del_service(self, id):
		# Forget the service and close the master's copy of its socket;
		# already-forked children keep their own descriptors.
		s = self.services[id]
		del self.services[id]
		s.close()

	def preserve_children(self):
		# Fork until maxprocess children exist (or a stop was requested).
		while self.maxprocess > len(self.children) and not self.stopcode:
			self.sanity_children()
			self.children.append(self.fork_child(self.serial))
			self.serial += 1
			time.sleep(0.1) # avoiding sigchld too fast
		return self.stopcode

	def SIGCHLD(self, signum, frame):
		# SIGCHLD handler: reap exited children, then watch for a "boom"
		# (children exiting in a rapid burst) and stop the server on one.
		if frame.f_code.co_name == 'SIGCHLD':
			# Delivered while this very handler was running: only note it;
			# the outer invocation re-runs itself when it is done.
			self.SIGCHLD_EVENT = 1
			return

		try:
			while 1:
				# pid 0: any child in our process group; WNOHANG: don't block.
				pid = os.waitpid(0, os.WNOHANG)[0]
				if not pid:
					break
				try:
					self.children.remove(pid)
				except ValueError: # dead before children.append finishing
					self.children_ignore.append(pid)
		except OSError:
			# Typically ECHILD: nothing left to wait for.
			pass

		if self.boom_threshold:
			tm = time.time()
			if tm - self.boom_ltm < self.boom_threshold:
				self.boom_cnt += 1
				if self.boom_cnt > self.boom_crpoint:
					self.boom_threshold = 0 # starting boom!
					self.log("BOOM DETECTED!")
					self.stop(SC_BOOM)
			else:
				self.boom_cnt = 0
				self.boom_ltm = tm

		if self.SIGCHLD_EVENT:
			self.SIGCHLD(signum, frame)
			self.SIGCHLD_EVENT = 0

	def signal_terminator(self, signum, frame):
		# SIGHUP/SIGTERM/SIGINT handler of the master.
		self.stopcode = TERMSIGS[signum]
		self.killall()

	def stop(self, stopcode, signum=signal.SIGTERM):
		self.stopcode = stopcode
		self.killall(signum)

	def sanity_children(self):
		# Drop pids that were reaped before fork_child()'s result was appended.
		for pid in self.children_ignore[:]:
			if pid in self.children:
				self.children_ignore.remove(pid)
				self.children.remove(pid)

	def killall(self, signum=signal.SIGTERM):
		# Send `signum` to every known child.
		self.sanity_children()
		for pid in self.children[:]:
			try:
				os.kill(pid, signum)
			except OSError:
				# The child may already have been reaped (the SIGCHLD handler
				# can run while we walk this snapshot); other errors are real.
				if sys.exc_info()[1].errno != errno.ESRCH:
					raise

	def _idle(self):
		# Wait for something to happen.  time.sleep() may return early when
		# a signal is handled, and returns at the latest after idle_timeout
		# seconds; unlike signal.pause() it cannot hang on a lost wakeup.
		time.sleep(self.idle_timeout)

	def wait(self):
		# Block until every child has been reaped by the SIGCHLD handler.
		while 1:
			self.sanity_children()
			if not self.children:
				break
			self._idle()

	def _reset_child_signals(self):
		# Runs in a freshly forked child.  The master's handlers installed by
		# loop() survive fork(): left in place, SIGCHLD would reap the
		# child's own subprocesses, and the termination signals would run
		# killall() on the master's stale pid list (possibly reused pids).
		# Termination signals stay caught (the child is not killed outright,
		# as before) but do nothing; Child.loop() installs its own SIGTERM.
		def ignore(signum, frame):
			pass
		if signal.getsignal(signal.SIGCHLD) == self.SIGCHLD:
			signal.signal(signal.SIGCHLD, signal.SIG_DFL)
		for signum in TERMSIGS.keys():
			if signal.getsignal(signum) == self.signal_terminator:
				signal.signal(signum, ignore)

	def fork_child(self, childn):
		# Fork one worker; returns its pid in the master, never returns in
		# the child (which always leaves through os._exit(0)).
		pid = os.fork()
		if not pid:
			try:
				try:
					self._reset_child_signals()
					c = Child(self.services, childn, self.limit)
					self.childinit(c)
					c.loop()
					del c
				except KeyboardInterrupt:
					pass
				except socket.error:
					why = sys.exc_info()[1]
					# EINTR is the expected end of a blocking accept() or
					# recvfrom() interrupted by the SIGTERM that stops us.
					if not why.args or why.args[0] != errno.EINTR:
						import traceback
						traceback.print_exc()
				except:
					import traceback
					traceback.print_exc()
			finally:
				self.childexit()
				os._exit(0)
		return pid

	def oninit(self):
		# on initializing server
		pass

	def atexit(self):
		# on terminating server
		pass

	def childinit(self, childinst):
		# on initializing each child process
		pass

	def childexit(self):
		# on terminating each child process
		pass

	def loop(self):
		# Master main loop: keep the children preforked until stopped, then
		# wait for all of them and return the stop code.
		signal.signal(signal.SIGCHLD, self.SIGCHLD)
		signal.signal(signal.SIGHUP,  self.signal_terminator)
		signal.signal(signal.SIGTERM, self.signal_terminator)
		signal.signal(signal.SIGINT,  self.signal_terminator)

		while not self.stopcode:
			if not self.preserve_children():
				self._idle()

		signal.signal(signal.SIGINT,  signal.SIG_DFL)
		signal.signal(signal.SIGTERM, signal.SIG_DFL)
		signal.signal(signal.SIGHUP,  signal.SIG_DFL)

		self.wait()
		# in order to clean dying children
		signal.signal(signal.SIGCHLD, signal.SIG_DFL)

		return self.stopcode
	

class Child:
	""" SocketFarm worker, created by SocketFarm.fork_child() after fork()

	limit is (tcp, udp, cpu).  The child stops serving once `tcp` stream
	connections or `udp` datagrams have been handled (only positive counts
	are decremented, so a negative count means unlimited while 0 stops after
	the first round), or once it has used more than `cpu` seconds of CPU
	time (0 disables that check).
	"""

	def __init__(self, services, spid, limit):
		self.spid     = spid
		self.services = services
		self.socks    = list(services.keys())
		self.tcplimit, self.udplimit, self.cpulimit = limit
		# Datagram services get one long-lived Control instance, which
		# replaces the Control class stored on the Service object.
		for s in services.values():
			if s.socktype == SOCK_DGRAM:
				s.control = s.control(s.sock, s.address, s.side)

	def _cputime(self):
		# User + system CPU seconds used by this process.  os.times() stands
		# in for time.clock(), which no longer exists on Python 3.8+.
		t = os.times()
		return t[0] + t[1]

	def loop(self):
		# Serve ready sockets until a limit is reached or SIGTERM arrives.
		from select import error as select_error
		signal.signal(signal.SIGTERM, self.signal_terminator)

		while 1:
			try:
				readable = select(self.socks, [], [])[0]
			except select_error:
				# A handled signal (e.g. the SIGTERM asking us to stop) makes
				# select() fail with EINTR: go on to the limit check instead
				# of dying with a traceback.
				why = sys.exc_info()[1]
				if not why.args or why.args[0] != errno.EINTR:
					raise
				readable = []

			for fno in readable:
				svc = self.services[fno]
				if svc.socktype == SOCK_STREAM:
					conn, addr = svc.sock.accept()
					svc.control(conn, addr, svc.side).start() # blocks until ends
					if self.tcplimit > 0:
						self.tcplimit -= 1
				else:
					data, addr = svc.sock.recvfrom(svc.recvbuff)
					svc.control.process(data, addr)
					if self.udplimit > 0:
						self.udplimit -= 1

			if (not self.tcplimit or not self.udplimit or
					(self.cpulimit and self._cputime() > self.cpulimit)):
				break

		signal.signal(signal.SIGTERM, signal.SIG_DFL)

	def signal_terminator(self, signum, frame):
		# SIGTERM handler: zero both limits so loop() exits after this round.
		self.tcplimit, self.udplimit = 0, 0


if __name__ == "__main__":

	# Demo: a TCP and a UDP "hello" service sharing port 9999, served by 4
	# preforked children; each child exits after 10 TCP connections or 10
	# datagrams, whichever comes first, and the master forks a replacement.

	class HelloTCP(Control):
		def run(self):
			# sendall() retries the partial writes a bare send() may leave.
			self.sock.sendall("Hello!!")
			self.sock.close()

	class HelloUDP(Control):
		socktype = SOCK_DGRAM
		def process(self, data, addr):
			sys.stdout.write("%s %s >> %s\n" % (os.getpid(), addr, data))
			self.sock.sendto('Hello, ' + data, addr)

	s = SocketFarm(procnum=4, limit=(10, 10, 0), boom=(1.0, 10))
	s.add_service(HelloUDP, ('', 9999))
	s.add_service(HelloTCP, ('', 9999))
	s.loop()

