Hi,
I am using HibernateUtil in my DAO layer, which in turn is accessed by session beans and MDBs. I wanted to isolate all calls to Hibernate (and HibernateUtil) to the DAO layer. The problem was that if one DAO called another DAO, the session would be opened and closed once in each DAO. I need support for nested getSession calls in the same thread. Hence I updated HibernateUtil to keep a count of the getSession calls made in the same thread before the session is finally closed. This allows me to call getSession and closeSession any number of times and be assured that I get the same session back until the final closeSession is called.
I also implemented the same counting logic for beginning and committing transactions.
Here is the code: Please provide feedback.
Code:
package com.panacya.pcyservice.hibernate;
import net.sf.hibernate.*;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import javax.naming.Context;
import javax.naming.InitialContext;
import javax.naming.NamingException;
/**
 * Basic Hibernate helper class, handles SessionFactory, Session and Transaction.
 * <p>
 * The SessionFactory is looked up lazily from JNDI (see
 * {@link #HIBERNATE_FACTORY_JNDI_NAME}) the first time it is needed. The
 * Session, the Transaction and the Interceptor are held in thread local
 * variables. Hibernate and JNDI failures are wrapped in a
 * HibernateServiceException.
 * <p>
 * Calls may be nested within one thread: every {@link #getSession()} must be
 * paired with one {@link #closeSession()}, and every
 * {@link #beginTransaction()} with one {@link #commitTransaction()}. Only the
 * outermost close/commit actually closes the session or commits the
 * transaction; inner calls merely adjust a per-thread counter.
 * {@link #rollbackTransaction()} always acts immediately and discards the
 * thread's session, transaction and both counters.
 *
 * @author christian@hibernate.org
 */
public class HibernateUtil
{
    /** Logger instance */
    private static final Log _log = LogFactory.getLog(HibernateUtil.class);

    /** Session factory, resolved lazily from JNDI or injected via setSessionFactory */
    private static SessionFactory _sessionFactory = null;

    /** Thread local session */
    private static final ThreadLocal _threadSession = new ThreadLocal();

    /** Thread local transaction */
    private static final ThreadLocal _threadTransaction = new ThreadLocal();

    /** Thread local interceptor */
    private static final ThreadLocal _threadInterceptor = new ThreadLocal();

    /** Thread local count of getSession calls not yet matched by closeSession */
    private static final ThreadLocal _threadSessionGetCount = new ThreadLocal();

    /** Session factory jndi name */
    public static final String HIBERNATE_FACTORY_JNDI_NAME = "java:hibernate/HibernateFactory";

    /** Thread local count of beginTransaction calls not yet matched by commitTransaction */
    private static final ThreadLocal _threadTxnBeginCount = new ThreadLocal();

    /**
     * Gets the session factory, looking it up in JNDI on first use.
     *
     * @return the session factory
     *
     * @throws HibernateServiceException if error occurs getting the session factory
     */
    public static SessionFactory getSessionFactory() throws HibernateServiceException
    {
        if(_sessionFactory == null)
        {
            updateSessionFactory();
        }

        return _sessionFactory;
    }

    /**
     * Looks the session factory up in JNDI and stores the reference.
     *
     * @throws HibernateServiceException if error occurs looking up the session factory
     */
    private static void updateSessionFactory() throws HibernateServiceException
    {
        Context context = null;

        try
        {
            context = new InitialContext();
            _sessionFactory = (SessionFactory) context.lookup(HIBERNATE_FACTORY_JNDI_NAME);
        }
        catch(NamingException e)
        {
            throw new HibernateServiceException("Cannot create SessionFactory", e);
        }
        finally
        {
            // Release the naming context; a failure here must not hide the lookup result.
            if(context != null)
            {
                try
                {
                    context.close();
                }
                catch(NamingException closeFailure)
                {
                    _log.debug("Ignoring failure while closing JNDI context", closeFailure);
                }
            }
        }
    }

    /**
     * Discards the cached SessionFactory reference and looks it up in JNDI again.
     * If the lookup fails the cached reference stays empty, so the next
     * {@link #getSessionFactory()} call retries the lookup.
     *
     * @throws HibernateServiceException if error occurs looking up the session factory
     */
    public static void rebuildSessionFactory() throws HibernateServiceException
    {
        _sessionFactory = null;
        updateSessionFactory();
    }

    /**
     * Gets the hibernate session of the current thread, opening one if needed,
     * and records the call so that the matching {@link #closeSession()} can be
     * recognised as nested or outermost.
     *
     * @return hibernate session
     *
     * @throws HibernateServiceException if error occurs getting the session
     */
    public static Session getSession() throws HibernateServiceException
    {
        if(_log.isTraceEnabled())
        {
            // The count logged here is the value before this call is counted.
            String callingMethod = getCallingMethod();
            _log.trace("Getting Session[" + getThreadSessionGetCount() + "] " + callingMethod);
        }

        Session session = getSessionInternal();
        incrementSessionGetCount();

        return session;
    }

    /**
     * Returns the method that called into HibernateUtil. Useful for debugging and tracing.
     * <p>
     * The first stack frame that does not belong to this class is reported, so
     * the result is the external caller regardless of how many internal
     * helpers sit in between.
     *
     * @return file name, method name and line number of the caller, or a
     *         placeholder when no such frame is available
     */
    private static String getCallingMethod()
    {
        StackTraceElement[] frames = new Exception().getStackTrace();
        String ownClassName = HibernateUtil.class.getName();

        for(int i = 0; i < frames.length; i++)
        {
            StackTraceElement frame = frames[i];

            if(!ownClassName.equals(frame.getClassName()))
            {
                return frame.getFileName() + ":" + frame.getMethodName() + "():" + frame.getLineNumber();
            }
        }

        // The stack trace may be empty or contain only our own frames.
        return "<unknown caller>";
    }

    /**
     * Returns the thread local session, opening a new one from the session
     * factory (with the registered interceptor, if any) when the thread has none.
     * Opening a new session resets the session retrieval count to 0.
     *
     * @return hibernate session
     *
     * @throws HibernateServiceException if error occurs getting the session
     */
    private static Session getSessionInternal() throws HibernateServiceException
    {
        Session session = (Session) _threadSession.get();

        if(session != null)
        {
            return session;
        }

        try
        {
            _log.debug("Opening new Session for this thread.");

            Interceptor interceptor = getInterceptor();

            if(interceptor != null)
            {
                _log.debug("Using interceptor: " + interceptor.getClass());
                session = getSessionFactory().openSession(interceptor);
            }
            else
            {
                session = getSessionFactory().openSession();
            }

            _threadSession.set(session);
            resetSessionGetCount();
        }
        catch(HibernateException ex)
        {
            throw new HibernateServiceException(ex);
        }

        return session;
    }

    /**
     * Reads a per-thread counter, treating an unset counter as 0.
     *
     * @param counter thread local holding an Integer (or nothing)
     *
     * @return the current counter value
     */
    private static int readCount(ThreadLocal counter)
    {
        Integer value = (Integer) counter.get();

        return (value == null) ? 0 : value.intValue();
    }

    /**
     * Sets the thread local session retrieval count to 0
     */
    private static void resetSessionGetCount()
    {
        _threadSessionGetCount.set(new Integer(0));
    }

    /**
     * Increments the thread local session retrieval count
     */
    private static void incrementSessionGetCount()
    {
        Integer sessionGetCount = new Integer(getThreadSessionGetCount() + 1);
        _threadSessionGetCount.set(sessionGetCount);

        if(_log.isDebugEnabled())
        {
            _log.debug("Incrementing session get count to[" + sessionGetCount + "]");
        }
    }

    /**
     * Gets the current thread local session retrieval count
     *
     * @return the thread local retrieval count, 0 if none has been recorded
     */
    private static int getThreadSessionGetCount()
    {
        return readCount(_threadSessionGetCount);
    }

    /**
     * Closes the Session local to the thread once the outermost
     * {@link #getSession()} call has been matched; nested calls only decrement
     * the retrieval count. An unbalanced call (no matching getSession) is
     * logged and still runs the close path, leaving the count at 0.
     *
     * @throws HibernateServiceException if error occurs closing the session
     */
    public static void closeSession() throws HibernateServiceException
    {
        if(_log.isTraceEnabled())
        {
            _log.trace("Closing Session[" + getThreadSessionGetCount() + "] " + getCallingMethod());
        }

        try
        {
            decrementSessionGetCount();

            if(getThreadSessionGetCount() <= 0)
            {
                closeSessionInternal();
            }
        }
        catch(HibernateException ex)
        {
            throw new HibernateServiceException(ex);
        }
    }

    /**
     * Unbinds the session from the thread, resets the retrieval count and
     * closes the session if it is still open. The thread local state is
     * cleared before closing so it is consistent even when close() fails.
     *
     * @throws HibernateException if error occurs closing the session
     */
    private static void closeSessionInternal() throws HibernateException
    {
        Session s = (Session) _threadSession.get();
        _threadSession.set(null);
        resetSessionGetCount();

        if((s != null) && s.isOpen())
        {
            _log.debug("Closing Session of this thread.");
            s.close();
        }
    }

    /**
     * Decrements the session retrieval count. The count never drops below 0;
     * an attempt to do so is logged as a warning and otherwise ignored.
     */
    private static void decrementSessionGetCount()
    {
        int sessionGetCountValue = getThreadSessionGetCount();

        if(sessionGetCountValue <= 0)
        {
            _log.warn("Decrement get session count called when count <=0 " + getCallingMethod());
            resetSessionGetCount();

            return;
        }

        Integer sessionGetCount = new Integer(sessionGetCountValue - 1);
        _threadSessionGetCount.set(sessionGetCount);

        if(_log.isDebugEnabled())
        {
            _log.debug("Decrementing session get count to[" + sessionGetCount + "]");
        }
    }

    /**
     * Start a new database transaction, or join the one already running in
     * this thread. Each call must be matched by one {@link #commitTransaction()}.
     *
     * @throws HibernateServiceException if the thread has no open session or
     *         the transaction cannot be started
     */
    public static void beginTransaction() throws HibernateServiceException
    {
        if(_log.isTraceEnabled())
        {
            _log.trace("Beggining a transaction " + getCallingMethod());
        }

        Transaction tx = (Transaction) _threadTransaction.get();

        if(tx == null)
        {
            _log.debug("Starting new database transaction in this thread.");

            Session session = (Session) _threadSession.get();

            if(session == null)
            {
                throw new HibernateServiceException("Session is null, open a session first: {call getSession()}");
            }

            try
            {
                tx = session.beginTransaction();
                _threadTransaction.set(tx);
            }
            catch(HibernateException ex)
            {
                throw new HibernateServiceException(ex);
            }
        }

        // Counted only after the transaction exists, so a failed begin leaves the count untouched.
        incrementTxnBeginCount();
    }

    /**
     * Increments the thread local transaction begin count
     */
    private static void incrementTxnBeginCount()
    {
        Integer txnGetCount = new Integer(getThreadTxnBeginCount() + 1);
        _threadTxnBeginCount.set(txnGetCount);

        if(_log.isDebugEnabled())
        {
            _log.debug("Incrementing txn get count to[" + txnGetCount + "]");
        }
    }

    /**
     * Gets the thread local current transaction begin count
     *
     * @return the thread local transaction begin count, 0 if none has been recorded
     */
    private static int getThreadTxnBeginCount()
    {
        return readCount(_threadTxnBeginCount);
    }

    /**
     * Decrements the thread local transaction begin count. The count never
     * drops below 0; an attempt to do so is logged as a warning and otherwise
     * ignored.
     */
    private static void decrementTxnBeginCount()
    {
        int txnGetCountValue = getThreadTxnBeginCount();

        if(txnGetCountValue <= 0)
        {
            _log.warn("Decrement get txn count called when count <=0 " + getCallingMethod());
            _threadTxnBeginCount.set(new Integer(0));

            return;
        }

        Integer txnGetCount = new Integer(txnGetCountValue - 1);
        _threadTxnBeginCount.set(txnGetCount);

        if(_log.isDebugEnabled())
        {
            _log.debug("Decrementing txn get count to[" + txnGetCount + "]");
        }
    }

    /**
     * Forgets the thread's transaction and its begin count, so that the next
     * beginTransaction starts from a clean state.
     */
    private static void clearTransactionState()
    {
        _threadTransaction.set(null);
        _threadTxnBeginCount.set(null);
    }

    /**
     * Commit the database transaction once the outermost
     * {@link #beginTransaction()} call has been matched; nested calls only
     * decrement the begin count.
     * <p>
     * When the outermost level is reached the thread's transaction state is
     * cleared even if the transaction turns out to be already committed or
     * rolled back, so a finished transaction is never reused. If the commit
     * fails, the transaction is rolled back (which also closes the session)
     * and the commit failure is reported.
     *
     * @throws HibernateServiceException if the commit fails
     */
    public static void commitTransaction() throws HibernateServiceException
    {
        if(_log.isTraceEnabled())
        {
            _log.trace("Commiting a transaction " + getCallingMethod());
        }

        decrementTxnBeginCount();

        if(getThreadTxnBeginCount() > 0)
        {
            // Still inside an outer beginTransaction/commitTransaction pair.
            return;
        }

        Transaction tx = (Transaction) _threadTransaction.get();

        try
        {
            if((tx != null) && !tx.wasCommitted() && !tx.wasRolledBack())
            {
                _log.debug("Committing database transaction of this thread.");
                tx.commit();
            }

            clearTransactionState();
        }
        catch(HibernateException ex)
        {
            try
            {
                rollbackTransaction();
            }
            catch(HibernateServiceException rollbackFailure)
            {
                // Report the commit failure below; the rollback failure is secondary.
                _log.error("Rollback after failed commit also failed", rollbackFailure);
            }

            throw new HibernateServiceException(ex);
        }
    }

    /**
     * Rollback the database transaction and forcibly close the thread's session.
     * <p>
     * Rollback is not counted: it acts immediately regardless of nesting and
     * resets both the transaction begin count and the session retrieval count,
     * so the thread can be reused for new work afterwards.
     *
     * @throws HibernateServiceException if the rollback or the session close fails
     */
    public static void rollbackTransaction() throws HibernateServiceException
    {
        if(_log.isTraceEnabled())
        {
            _log.trace("Rolling back transaction " + getCallingMethod());
        }

        Transaction tx = (Transaction) _threadTransaction.get();

        try
        {
            // Clear first so the thread state is clean even if rollback() throws.
            clearTransactionState();

            if((tx != null) && !tx.wasCommitted() && !tx.wasRolledBack())
            {
                _log.debug("Tyring to rollback database transaction of this thread.");
                tx.rollback();
            }
        }
        catch(HibernateException ex)
        {
            throw new HibernateServiceException(ex);
        }
        finally
        {
            try
            {
                if(_log.isDebugEnabled())
                {
                    _log.debug("Transaction is rolled back. Closing session forcibly");
                }

                closeSessionInternal();
            }
            catch(HibernateException e)
            {
                throw new HibernateServiceException(e);
            }
        }
    }

    /**
     * Register a Hibernate interceptor with the current thread.
     * <p>
     * Every Session opened is opened with this interceptor after
     * registration. Has no effect if the current Session of the
     * thread is already open, effective on next close()/getSession().
     *
     * @param interceptor the interceptor to use for sessions opened by this thread
     */
    public static void registerInterceptor(Interceptor interceptor)
    {
        _threadInterceptor.set(interceptor);
    }

    /**
     * Gets the interceptor registered with the current thread.
     *
     * @return the interceptor, or null if none has been registered
     */
    public static Interceptor getInterceptor()
    {
        return (Interceptor) _threadInterceptor.get();
    }

    /**
     * Sets the session factory, replacing any reference obtained from JNDI.
     *
     * @param sessionFactory session factory
     */
    public static void setSessionFactory(SessionFactory sessionFactory)
    {
        _sessionFactory = sessionFactory;
    }
}
Thanks