""" 
   Copyright (C) 2001 PimenTech SARL (http://www.pimentech.net)

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Library General Public License as
   published by the Free Software Foundation; either version 2 of the
   License, or (at your option) any later version.

   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Library General Public License for more details.

   You should have received a copy of the GNU Library General Public
   License along with this library; see the file COPYING.LIB.  If not,
   write to the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
   Boston, MA 02111-1307, USA.  
"""

from stack import *
from map import *
from set import *
from graph import *
from sax import *
from string import *
from table import *
from relation import *

class PgmlGraph(Graph):
	"""Graph of tables and relations built from a .pgml (XML) schema file.

	Each vertex wraps either a Table or a Relation object (see get_tables and
	get_relations).  Edges join relations to tables and are labelled
	"min,max,k"; a third field k of '1' marks the table that receives the
	reference column for that relation (see _put_refs).
	"""

	class PgmlHandler(CommonHandler):
		"""SAX handler that populates a PgmlGraph while a PGML file is parsed."""

		def __init__(self, name, graph):
			CommonHandler.__init__(self, name)

			self.graph = graph

			# table construction options are copied from the owning graph
			self.idSequence = graph.idSequence
			self.usePostgreSqlIsa = graph.usePostgreSqlIsa
			self.useIndices = graph.useIndices

			# element currently being parsed (None outside <table>/<relation>)
			self.currentTable = None
			self.currentRelation = None

		def _insert_table(self, name):
			"Insert a Table named `name` into the graph and return the Table held by the resulting vertex."
			return self.graph.insert(Table(name, self.idSequence, self.usePostgreSqlIsa, self.useIndices)).object

		def startElement(self, name, attrs):
			"""Handle <table>, <relation>, <attribute> and <participation> start tags."""
			CommonHandler.startElement(self, name, attrs)

			# copy the SAX attributes into a Map; the `if valueOf[...]` tests below
			# presumably rely on Map yielding a false value for absent keys -- TODO confirm
			valueOf = Map("ValueOf") # { 'name':'', 'table':'', 'type':'', 'isa':'', 'default':'' ... }
			for attr in attrs:
				valueOf[attr] = attrs[attr]

			if name == 'table':
				self.currentTable = self._insert_table(valueOf['name'])
				self.graph.stderr.write("table %s inserted\n" % str(self.currentTable))

				if valueOf['isa']:
					parent = self._insert_table(valueOf['isa'])
					self.currentTable.isa = parent
					self.graph.stderr.write("table %s inherits from %s\n" % (str(self.currentTable), str(self.currentTable.isa)))

					# inheritance is modelled as an explicit "<child>_isa_<parent>" relation
					relation = Relation('%s_isa_%s' % (self.currentTable, parent))
					self.graph.insert_edge(self.currentTable, '1,1,1', relation)
					self.graph.insert_edge(relation, '1,1,0', parent)

			elif name == 'relation':
				self.currentRelation = self.graph.insert(Relation(valueOf['name'])).object
				if valueOf['isa']:
					self.currentRelation.isa = self._insert_table(valueOf['isa'])

			elif name == 'attribute':
				# an attribute belongs to the enclosing <table>, otherwise to the enclosing <relation>
				if self.parentTag == 'table':
					owner = self.currentTable
				else:
					owner = self.currentRelation
				for prop in ('type', 'default', 'constraints'):
					owner.insert_attribute(valueOf['name'], prop, valueOf[prop])

			elif name == 'participation':
				table = self._insert_table(valueOf['table'])
				self.graph.insert_edge(self.currentRelation, valueOf['type'], table)

		def endElement(self, name):
			"""Forget the current table/relation when its element closes."""
			CommonHandler.endElement(self, name)
			if name == 'table':
				self.currentTable = None
			elif name == 'relation':
				self.currentRelation = None

	def __init__(self, name = 'PgmlGraph', debug = None, idSequence = None, usePostgreSqlIsa = None, usePostgreSqlConstraints = None,
				 refDefault = None, useIndices = None):
		"""Create an empty, undirected schema graph.

		debug -- when true, progress messages go to sys.stderr; otherwise they
		         are written to the platform null device.
		idSequence, usePostgreSqlIsa, useIndices -- passed to every Table this
		         graph creates; usePostgreSqlIsa also decides whether isa
		         relations produce reference columns.
		usePostgreSqlConstraints -- add NOT NULL / REFERENCES constraints to
		         generated reference columns.
		refDefault -- default value given to generated reference columns.
		"""
		import os
		import sys

		Graph.__init__(self, name)
		# sys.stderr is named explicitly: this module does not import `stderr` itself;
		# os.devnull replaces the POSIX-only '/dev/null'
		if debug:
			self.stderr = sys.stderr
		else:
			self.stderr = open(os.devnull, 'w')

		self.directed = 0 # the graph is not directed
		self.idSequence = idSequence # id sequence
		self.refDefault = refDefault # ref default
		self.usePostgreSqlIsa = usePostgreSqlIsa # whether to use the PostgreSQL isa (inheritance) keyword
		self.usePostgreSqlConstraints = usePostgreSqlConstraints
		self.useIndices = useIndices
		self.simplified = 0 # not simplified already

	def _vertices_of_type(self, typeName, setName):
		"Return a Set named `setName` holding the vertices whose object.type equals `typeName`."
		result = Set(setName)
		for v in self.values():
			if v.object.type == typeName:
				result.insert(v)
		return result

	def get_relations(self):
		"return a set of vertices containing relations"
		return self._vertices_of_type('Relation', 'Relations')

	def get_tables(self):
		"return a set of vertices containing tables"
		return self._vertices_of_type('Table', 'Tables')

	def _how_much_relations_between(self, vtable1, vtable2): # with refs in table1
		"""Count the relations linking vtable1 to vtable2 in which vtable1 holds the ref.

		Only relations reached through vtable1's '0,1,1' and '1,1,1' edges are
		considered.  When the PostgreSQL isa keyword is used, the
		"<vtable1>_isa_<vtable2>" relation is not counted, and the relations of
		vtable1's parent table (object.isa) are added recursively.
		"""
		relations = Set('Relations')
		relations.insert_set(vtable1['0,1,1'])
		relations.insert_set(vtable1['1,1,1'])

		isaRelationName = '%s_isa_%s' % (vtable1, vtable2)
		n = 0
		for vrelation in relations.values():
			for participants in vrelation.values():
				if not participants.has_key(vtable2):
					continue
				# an isa relation only counts when PostgreSQL inheritance is not used
				if vrelation.name != isaRelationName or not self.usePostgreSqlIsa:
					n = n + 1

		if self.usePostgreSqlIsa and vtable1.object.isa:
			# with PostgreSQL inheritance, the inherited tables must be looked at too
			return n + self._how_much_relations_between(self[vtable1.object.isa], vtable2)
		return n

	def _how_much_times_occurs_on_relations(self, vtable, attribute): # with refs in table
		"""Count the relations reached through vtable's '0,1,1' and '1,1,1' edges
		whose Relation object has the key `attribute`."""
		relations = Set('Relations')
		relations.insert_set(vtable['0,1,1'])
		relations.insert_set(vtable['1,1,1'])
		n = 0
		for vrelation in relations.values():
			if vrelation.object.has_key(attribute):
				n = n + 1
		return n

	def _need_table(self, vrelation):
		"""Return 1 if `vrelation` must be turned into a table of its own, else 0.

		That is the case when its degree exceeds 2, or when none of its edges is
		flagged k == '1' with at least one participant (no table can hold the ref).
		"""
		if vrelation.degre() > 2:
			self.stderr.write('need to create a table for relation %s (degre > 2)\n' % str(vrelation))
			return 1

		self.stderr.write('looking for tables with ref for relation %s (degre < 3)\n' % str(vrelation))
		for (label, vertices) in vrelation.items(): # for each edge outgoing from that relation
			cardMin, cardMax, k = split(label, ',')
			if k == '1' and vertices.values(): # there is a table with ref
				self.stderr.write('found one\n')
				return 0

		self.stderr.write('no tables with ref for relation %s (degre < 3)\n' % str(vrelation))
		self.stderr.write('therefore need to create a table for relation %s\n' % str(vrelation))
		return 1

	def _simplify_ambiguous_relations(self):
		"""Replace every relation that needs a table (see _need_table) by a table.

		The relation vertex is removed.  A Table of the same name is inserted,
		carrying the relation's attributes and isa (if any).  Each former
		participant is then linked to the new table through a new binary
		relation "<participant>_<table>".
		"""
		for vrelation in self.get_relations().values(): # snapshot, so deletions below are safe
			if not self._need_table(vrelation):
				continue

			relation = vrelation.object
			self.stderr.write('deleting relation %s\n' % str(vrelation))
			del self[vrelation]

			self.stderr.write('inserting new table %s\n' % str(relation))
			table_relation = self.insert(Table(str(relation), self.idSequence, self.usePostgreSqlIsa, self.useIndices)).object

			if relation.isa:
				self.stderr.write('adding isa on newly created table %s\n' % str(relation))
				table_relation.isa = relation.isa
				new_isa_relation = Relation('%s_isa_%s' % (table_relation, table_relation.isa))
				self.insert_edge(table_relation, '1,1,1', new_isa_relation)
				self.insert_edge(new_isa_relation, '1,1,0', table_relation.isa)

			self.stderr.write('adding attributes of %s on %s\n' % (str(relation), str(table_relation)))
			table_relation.insert_map(relation)

			vtable_relation = self.insert(table_relation)
			for (label, vertices) in vrelation.items():
				cardMin, cardMax, k = split(label, ',')
				# a participant that held the ref keeps it and references the new
				# table ('1,1,0'); otherwise the new table holds the ref ('1,1,1')
				if k == '1':
					tableSideLabel = '1,1,0'
				else:
					tableSideLabel = '1,1,1'

				for vtable in vertices.values(): # for each participation on relation
					table = vtable.object

					self.stderr.write('deleting edge %s -- %s -- %s\n' % (str(table), label, str(relation)))
					del vtable[label][vrelation]

					self.stderr.write('creating relation %s_%s\n' % (str(table), str(table_relation)))
					new_relation = Relation('%s_%s' % (str(table), str(table_relation)))
					vnew_relation = self.insert(new_relation)
					self.stderr.write('inserting edge %s -- %s -- %s\n' % (new_relation, label, table))
					self.insert_edge(vnew_relation, label, vtable)

					self.stderr.write('inserting edge %s -- %s -- %s\n' % (new_relation, tableSideLabel, table_relation))
					self.insert_edge(vnew_relation, tableSideLabel, vtable_relation)

	def _put_refs(self): # only after simplify_ambiguous_relations() --> no more degre > 2
		"""Add reference columns and relation attributes to the tables that hold refs.

		For every relation edge flagged k == '1', each table on that edge
		receives a reference to every other participant, plus the relation's
		own attributes (see _put_refs_from).
		"""
		for vrelation in self.get_relations().values(): # for each relation
			for (label, vertices) in vrelation.items(): # for each edge outgoing from that relation
				cardMin, cardMax, k = split(label, ',')
				if k != '1':
					continue

				self.stderr.write('putting refs on tables (x,x,1)\n')
				for vtable_with_ref in vertices.values(): # for each participation table of relation
					self._put_refs_from(vrelation, vtable_with_ref)

	def _put_refs_from(self, vrelation, vtable_with_ref):
		"""Give vtable_with_ref's table one ref per other participant of vrelation.

		Each ref column is named "ref[_<relation>]_<table>".  The relation
		suffix is added when several relations link the same pair of tables.
		The relation's attributes are also copied onto the table.  If a name is
		already present or used by several relations, it is prefixed with
		"<relation>_".
		"""
		relation = vrelation.object
		table_with_ref = vtable_with_ref.object

		for (label, vertices) in vrelation.items():
			cardMin, cardMax, k = split(label, ',')

			for vtable in vertices.values(): # for each participation table of relation
				table = vtable.object
				if table == table_with_ref: # the ref-holding table gets no ref to itself
					continue

				# disambiguation when two refs point to the same table
				diff = ''
				if self._how_much_relations_between(vtable_with_ref, vtable) > 1:
					diff = '_%s' % str(relation)
				refName = 'ref%s_%s' % (diff, str(table))

				# unless this is the isa relation and the PostgreSQL isa keyword is used:
				isaRelationName = '%s_isa_%s' % (table_with_ref, table)
				if vrelation.name != isaRelationName or not self.usePostgreSqlIsa:
					self._put_ref_attribute(table_with_ref, refName, table, cardMin)

				self.stderr.write('adding attributes of %s on %s\n' % (str(relation), str(table_with_ref)))
				for (name, properties) in relation.items():
					if self._how_much_times_occurs_on_relations(vtable_with_ref, name) > 1 or table_with_ref.has_key(name):
						table_with_ref["%s_%s" % (str(relation), name)] = properties
					else:
						table_with_ref[name] = properties

	def _put_ref_attribute(self, table_with_ref, refName, table, cardMin):
		"""Declare column `refName` on table_with_ref as an int4 reference to `table`.

		cardMin is the min cardinality read from `table`'s edge label.  With
		usePostgreSqlConstraints, '1' yields NOT NULL (+ REFERENCES ... ON
		DELETE CASCADE) and anything else yields REFERENCES ... ON DELETE SET
		NULL.  The REFERENCES clause is omitted when `table` has an isa and the
		PostgreSQL isa keyword is used.
		"""
		# NOTE(review): the min cardinality used is the one on the *referenced*
		# table's edge (unchanged from the original); confirm this is intended
		self.stderr.write('put %s on %s\n' % (refName, str(table_with_ref)))
		table_with_ref.insert_attribute(refName, 'type', 'int4')

		if self.refDefault: # default for the reference (option -r)
			table_with_ref.insert_attribute(refName, 'default', self.refDefault)

		# for ldb (independent of the relational-database generation):
		table_with_ref.insert_attribute(refName, 'references', str(table))

		if not self.usePostgreSqlConstraints:
			return

		inherits = table.isa and self.usePostgreSqlIsa
		if cardMin == '1':
			if inherits:
				table_with_ref.insert_attribute(refName, 'constraints', 'NOT NULL')
			else:
				table_with_ref.insert_attribute(refName, 'constraints',
												'NOT NULL REFERENCES %s ON DELETE CASCADE' % str(table))
		elif not inherits:
			table_with_ref.insert_attribute(refName, 'constraints',
											'REFERENCES %s ON DELETE SET NULL' % str(table))

	def simplify(self):
		"""Turn ambiguous relations into tables, then add reference columns (once only)."""
		if self.simplified: return

		self._simplify_ambiguous_relations()
		self._put_refs()

		self.simplified = 1

	def read(self, filename):
		"""Parse the PGML file `filename` into this graph using the xmlproc SAX driver."""
		pf = saxexts.ParserFactory()
		p = pf.make_parser('xml.sax.drivers.drv_xmlproc')
		# every SAX role gets its own handler instance, all feeding this graph
		p.setDocumentHandler(self.PgmlHandler('doc_handler', self))
		p.setDTDHandler(self.PgmlHandler('dtd_handler', self))
		p.setErrorHandler(self.PgmlHandler('err_handler', self))
		p.setEntityResolver(self.PgmlHandler('ent_handler', self))
		p.parse(filename)

	def write_dot(self, output, options = None, tableNames = None):
		"""Write the graph to `output` in Graphviz dot format.

		options -- raw dot attribute line; defaults to fixed page/rotate/margin settings.
		tableNames -- optional space-separated table names.  When given, only
		              those tables are written, together with the relations
		              having more than one edge into them and the edges towards
		              them.
		"""
		output.write('graph %s {\n' % self.name)

		if options:
			output.write('%s\n' % options)
		else:
			output.write('page="8.25,11.75";rotate=90;margin=0.2;\n')

		if tableNames:
			tableSet = Set()
			for tableName in split(tableNames, ' '):
				tableSet.insert(tableName)

		for v in self.values():

			if v.object.type == 'Table':
				if not tableNames or tableSet.has_key(v):
					v.object.write_dot(output)
				continue

			if tableNames:
				tableCount = 0 # number of edges into the selected tables
				for label in v.keys():
					for vdest in v[label].values():
						if tableSet.has_key(vdest):
							tableCount = tableCount + 1
				if tableCount <= 1:
					continue
			v.object.write_dot(output)

			for label in v.keys():
				for vdest in v[label].values():
					if tableNames and not tableSet.has_key(vdest):
						continue
					output.write('%s -- %s [ label="%s" ];\n' % (str(v.object), str(vdest.object), label))

		output.write('}\n')
		return

	def write_pgml(self, output):
		"""Write the schema as PGML: every table (dependencies first), then every relation."""
		self.clear_tags()
		output.write('<schema>\n')
		for vtable in self.get_tables().values():
			if not vtable.tag:
				self._write_pgml(vtable, output)

		for v in self.get_relations().values():
			relation = v.object
			output.write("<relation name='%s'>\n" % str(relation))
			relation._write_attributes_pgml(output)
			for (label, vertices) in v.items():
				for table in vertices.values():
					output.write("<participation table='%s' type='%s'/>\n" % (str(table), str(label)))
			output.write("</relation>\n")
		output.write('</schema>\n')

	def _write_dependencies_first(self, vtable, recurse, output):
		"""Tag vtable as visited, then call `recurse(t, output)` on each untagged table it depends on.

		vtable's dependencies are the tables on the k == '0' side of the
		relations that vtable joins through a k == '1' edge (the tables its refs
		point to).  The `tag` marks guarantee each table is visited once, even on cycles.
		"""
		vtable.tag = 1
		for (label, svrel) in vtable.items(): # svrel : set of relations
			if self._label_dir(label) != '1':
				continue
			for vrel in svrel.values():
				for (relLabel, svt) in vrel.items(): # svt : set of tables
					if self._label_dir(relLabel) != '0':
						continue
					for t in svt.values():
						if not t.tag:
							recurse(t, output)

	def _write_pgml(self, vtable, output):
		"Write vtable's table as PGML after the tables it depends on."
		self._write_dependencies_first(vtable, self._write_pgml, output)
		vtable.object.write_pgml(output)

	def _label_dir(self, label):
		"Return the k (ref direction) field of a 'min,max,k' edge label."
		return split(label, ',')[2]

	def write_sql(self, output):
		"""Write every table as SQL, each one after the tables it depends on."""
		self.clear_tags()
		for vtable in self.get_tables().values():
			if not vtable.tag:
				self._write_sql(vtable, output)

	def _write_sql(self, vtable, output):
		"Write vtable's table as SQL after the tables it depends on."
		self._write_dependencies_first(vtable, self._write_sql, output)
		vtable.object.write_sql(output)

	def write_ssql(self, output):
		"""Write every table with Table.write_ssql, each one after the tables it depends on."""
		self.clear_tags()
		for vtable in self.get_tables().values():
			if not vtable.tag:
				self._write_ssql(vtable, output)

	def _write_ssql(self, vtable, output):
		"Write vtable's table with Table.write_ssql after the tables it depends on."
		self._write_dependencies_first(vtable, self._write_ssql, output)
		vtable.object.write_ssql(output)

		
