Here is my solution so far. If someone with deeper knowledge of Hibernate's internals could review it and make it safe, that would be great!
Code (the `@CascadedVersion` marker annotation):
import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
/**
 * Marks an association whose target entity should receive a forced optimistic-lock version
 * increment when the entity declaring the association is updated.
 * <p>
 * The annotation is kept at runtime and may be placed on fields or on getter methods, so that
 * the pre-update event listener can discover it through reflection at start-up.
 */
@Retention(value = RUNTIME)
@Target(value = {METHOD, FIELD})
public @interface CascadedVersion {
    // Pure marker annotation: it declares no attributes.
}
Code (the pre-update event listener that forces the cascaded version increment):
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import org.hibernate.EntityMode;
import org.hibernate.LockMode;
import org.hibernate.annotations.common.reflection.ReflectionManager;
import org.hibernate.annotations.common.reflection.XClass;
import org.hibernate.annotations.common.reflection.XProperty;
import org.hibernate.annotations.common.reflection.java.JavaReflectionManager;
import org.hibernate.cfg.Configuration;
import org.hibernate.engine.EntityEntry;
import org.hibernate.engine.PersistenceContext;
import org.hibernate.event.EventSource;
import org.hibernate.event.Initializable;
import org.hibernate.event.PreUpdateEvent;
import org.hibernate.event.PreUpdateEventListener;
import org.hibernate.mapping.PersistentClass;
import org.hibernate.mapping.Property;
import org.hibernate.persister.entity.EntityPersister;
public class CascadedVersionEventListener implements PreUpdateEventListener, Initializable {
private static final ReflectionManager manager = new JavaReflectionManager();
private static final long serialVersionUID = 1L;
private final Map<String, Set<String>> cascadeProperties = new HashMap<String, Set<String>>();
public void initialize(Configuration configuration) {
Iterator<?> persistentClassIterator = configuration.getClassMappings();
while (persistentClassIterator.hasNext()) {
PersistentClass persistentClass = PersistentClass.class.cast(persistentClassIterator.next());
Set<String> properties = new HashSet<String>();
Iterator<?> propertyIterator = persistentClass.getReferenceablePropertyIterator();
while (propertyIterator.hasNext()) {
Property property = Property.class.cast(propertyIterator.next());
if (isCascadedVersion(property)) {
properties.add(property.getName());
}
}
this.cascadeProperties.put(persistentClass.getEntityName(), properties);
}
}
public boolean onPreUpdate(PreUpdateEvent event) {
Set<Object> checkedEntities = new HashSet<Object>();
Queue<Object> entitiesToCheck = new LinkedList<Object>();
entitiesToCheck.add(event.getEntity());
while (!entitiesToCheck.isEmpty()) {
Object entity = entitiesToCheck.poll();
String entityName = entity.getClass().getName();
if (checkedEntities.add(entity)) {
EventSource source = event.getSession();
PersistenceContext context = source.getPersistenceContext();
EntityPersister persister = source.getEntityPersister(entityName, entity);
EntityMode mode = persister.guessEntityMode(entity);
EntityEntry entry = context.getEntry(entity);
Object[] oldState = entry.getLoadedState();
Object[] newState = persister.getPropertyValues(entity, mode);
boolean dirty = persister.findDirty(oldState, newState, entity, event.getSession()) != null;
boolean versionable = persister.getVersionProperty() >= 0;
if (versionable && !dirty) {
Serializable id = persister.getIdentifier(entity, mode);
if (entity != event.getEntity()) {
PreUpdateEvent event2 = new PreUpdateEvent(entity, id, newState, oldState, persister, source);
// force hibernate validator here...
for (PreUpdateEventListener listener : source.getListeners().getPreUpdateEventListeners()) {
listener.onPreUpdate(event2);
}
source.lock(entity, LockMode.FORCE);
}
}
for (String propertyName : this.cascadeProperties.get(entityName)) {
entitiesToCheck.offer(persister.getPropertyValue(entity, propertyName, mode));
}
}
}
return false;
}
private boolean isCascadedVersion(Property property) {
Class<?> mappedClass = property.getPersistentClass().getMappedClass();
XClass xClass = manager.toXClass(mappedClass);
for (XProperty xProperty : xClass.getDeclaredProperties(property.getPropertyAccessorName())) {
if (xProperty.getName().equals(property.getName()) && xProperty.isAnnotationPresent(CascadedVersion.class)) {
return true;
}
}
return false;
}
}
Test code (sample entities and a driver program):
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import javax.persistence.Entity;
import javax.persistence.FetchType;
import javax.persistence.GeneratedValue;
import javax.persistence.Id;
import javax.persistence.OneToMany;
import javax.persistence.Table;
import javax.persistence.Version;
import org.hibernate.annotations.Cascade;
import org.hibernate.annotations.CascadeType;
import org.hibernate.validator.AssertTrue;
/**
 * An order made of {@link OrderLine}s, at most one line per distinct {@link Product}.
 * <p>
 * Lines are loaded eagerly and cascaded (including orphan deletion), so removing a line from
 * {@link #lines} deletes it.
 */
@Entity
@Table(name = "ORDERS")
public class Order {
    @Id
    @GeneratedValue
    Long id;

    /** Optimistic-lock version column. */
    @Version
    Integer version;

    @OneToMany(mappedBy = "order", fetch = FetchType.EAGER)
    @Cascade( {CascadeType.ALL, CascadeType.DELETE_ORPHAN})
    private Set<OrderLine> lines = new HashSet<OrderLine>();

    /**
     * Adds {@code quantity} units of {@code product} to this order.
     * <p>
     * If a line for the product already exists, its quantity is adjusted (a negative
     * {@code quantity} decreases it) and the line is removed once its quantity drops to zero or
     * below. Otherwise a new line is created, but only for a positive quantity.
     *
     * @param product  the product to add; matched against existing lines with {@code equals}
     * @param quantity the number of units to add (may be negative for an existing line)
     */
    public void add(Product product, int quantity) {
        Iterator<OrderLine> iterator = this.lines.iterator();
        while (iterator.hasNext()) {
            OrderLine line = iterator.next();
            if (line.getProduct().equals(product)) {
                line.addQuantity(quantity);
                if (line.getQuantity() <= 0) {
                    iterator.remove();
                }
                return;
            }
        }
        // Never create a line with a zero or negative quantity: the branch above guarantees an
        // existing line never keeps one, so a new line must not start with one either.
        if (quantity > 0) {
            this.lines.add(new OrderLine(this, product, quantity));
        }
    }

    /** Returns a read-only view of the order lines. */
    public Set<OrderLine> getLines() {
        return Collections.unmodifiableSet(this.lines);
    }

    /** Returns the sum of the sub-totals of all lines. */
    public double getTotal() {
        double total = 0;
        for (OrderLine line : this.lines) {
            total += line.getSubTotal();
        }
        return total;
    }

    /**
     * Business invariant: the order total must stay below 150.
     * <p>
     * NOTE(review): Hibernate Validator 3 may only consider getter-style methods
     * ({@code getX}/{@code isX}) as validated properties; confirm that this method is actually
     * picked up.
     */
    @AssertTrue
    boolean checkTotal() {
        return getTotal() < 150;
    }
}
Code:
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.Id;
import javax.persistence.JoinColumn;
import javax.persistence.ManyToOne;
import javax.persistence.Table;
import org.hibernate.annotations.NaturalId;
/**
 * One line of an {@link Order}: a quantity of a single {@link Product}.
 * <p>
 * A line is identified by its natural id (order, product), which is what
 * {@link #equals(Object)} and {@link #hashCode()} compare. The {@link CascadedVersion} marker on
 * {@code order} asks the cascaded-version listener to force a version increment on the order
 * whenever a line is updated.
 */
@Entity
@Table(name = "ORDER_LINES")
public class OrderLine {
    @Id
    @GeneratedValue
    Long id;

    @ManyToOne
    @JoinColumn
    @NaturalId
    @CascadedVersion
    Order order;

    @ManyToOne
    @JoinColumn
    @NaturalId
    private Product product;

    private int quantity;

    /** No-arg constructor required by JPA/Hibernate to instantiate lines loaded from the database. */
    protected OrderLine() {
    }

    /**
     * Creates a line.
     *
     * @param order    the owning order
     * @param product  the ordered product
     * @param quantity the initial number of units
     */
    protected OrderLine(Order order, Product product, int quantity) {
        this.order = order;
        this.quantity = quantity;
        this.product = product;
    }

    /**
     * Two lines are equal when they belong to the same order and refer to equal products.
     * {@code instanceof} (rather than {@code getClass()}) keeps the relation symmetric with
     * subclasses such as Hibernate proxies.
     */
    @Override
    public boolean equals(Object that) {
        if (this == that) {
            return true;
        }
        if (!(that instanceof OrderLine)) {
            return false;
        }
        OrderLine other = (OrderLine) that;
        // Read the other side through accessors: on an uninitialized proxy its fields may not be
        // populated, while its accessors delegate to the real instance.
        return sameValue(this.order, other.getOrder()) && sameValue(this.product, other.getProduct());
    }

    /** Returns the owning order. */
    public Order getOrder() {
        return this.order;
    }

    /** Returns the ordered product. */
    public Product getProduct() {
        return this.product;
    }

    /** Returns the number of units. */
    public int getQuantity() {
        return this.quantity;
    }

    /** Returns quantity times the product's unit price. */
    public double getSubTotal() {
        return this.quantity * this.product.getPrice();
    }

    /** Consistent with {@link #equals(Object)}; tolerates a not-yet-populated natural id. */
    @Override
    public int hashCode() {
        return hashOf(this.order) ^ hashOf(this.product);
    }

    /**
     * Adjusts the quantity by {@code quantity}.
     *
     * @param quantity the delta; may be negative
     */
    void addQuantity(int quantity) {
        this.quantity += quantity;
    }

    /** Null-safe equality. */
    private static boolean sameValue(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Null-safe hash code (0 for {@code null}). */
    private static int hashOf(Object value) {
        return value == null ? 0 : value.hashCode();
    }
}
Code:
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.Id;
import javax.persistence.Table;
import javax.persistence.Version;
import org.hibernate.annotations.NaturalId;
/**
 * A product identified by its natural id, its name.
 * <p>
 * {@link #equals(Object)} and {@link #hashCode()} are based on the name only.
 */
@Entity
@Table(name = "PRODUCTS")
public class Product {
    @Id
    @GeneratedValue
    Long id;

    /** Optimistic-lock version column. */
    @Version
    Integer version;

    @NaturalId
    private String name;

    private double price;

    /** No-arg constructor required by JPA/Hibernate to instantiate products loaded from the database. */
    protected Product() {
    }

    /**
     * Creates a product.
     *
     * @param name  the unique product name (natural id)
     * @param price the unit price
     */
    public Product(String name, double price) {
        this.name = name;
        this.price = price;
    }

    /**
     * Two products are equal when their names are equal. {@code instanceof} keeps the relation
     * symmetric with subclasses such as Hibernate proxies, and the other side is read through
     * its accessor because an uninitialized proxy's fields may not be populated.
     */
    @Override
    public boolean equals(Object that) {
        if (this == that) {
            return true;
        }
        if (!(that instanceof Product)) {
            return false;
        }
        String otherName = ((Product) that).getName();
        return this.name == null ? otherName == null : this.name.equals(otherName);
    }

    /** Returns the product name. */
    public String getName() {
        return this.name;
    }

    /** Returns the unit price. */
    public double getPrice() {
        return this.price;
    }

    /** Consistent with {@link #equals(Object)}; 0 while the name is not populated. */
    @Override
    public int hashCode() {
        return this.name == null ? 0 : this.name.hashCode();
    }
}
Code:
import javax.persistence.EntityManager;
import javax.persistence.EntityManagerFactory;
import javax.persistence.Persistence;
/**
 * Demonstrates the effect of {@link CascadedVersion} on {@code OrderLine.order}.
 * <p>
 * An order worth 95 (45 + 50) is stored. It is then modified concurrently by two entity
 * managers, each adding one unit of a different product. Either change alone keeps the total
 * below 150; both together do not. The final reload prints whether the invariant still holds.
 */
public class Main {

    /** Entry point; uses the persistence unit named {@code "default"}. */
    public static void main(String[] args) {
        EntityManagerFactory factory = Persistence.createEntityManagerFactory("default");
        try {
            Product product1 = new Product("hibernate in action", 45.0);
            Product product2 = new Product("java persistence with hibernate", 50.0);

            Long orderId = createInitialData(factory, product1, product2);
            modifyConcurrently(factory, orderId, product1, product2);
            Order order3 = reload(factory, orderId);

            // Assert the total is < 150
            System.out.println("total price: " + order3.getTotal());
            System.out.println("invariant status: " + (order3.getTotal() < 150 ? "OK" : "KO"));
        } finally {
            factory.close();
        }
    }

    /**
     * Persists the two products and an order containing one unit of each.
     *
     * @return the generated order id (read back rather than assuming it is 1)
     */
    private static Long createInitialData(EntityManagerFactory factory, Product product1, Product product2) {
        Order order = new Order();
        order.add(product1, 1);
        order.add(product2, 1);
        EntityManager manager = factory.createEntityManager();
        try {
            manager.getTransaction().begin();
            manager.persist(product1);
            manager.persist(product2);
            manager.persist(order);
            manager.getTransaction().commit();
        } finally {
            rollbackIfActive(manager);
            manager.close();
        }
        return order.id;
    }

    /**
     * Loads the order in two entity managers, modifies each copy, and commits both. With the
     * cascaded version in place, the second commit is expected to fail.
     */
    private static void modifyConcurrently(EntityManagerFactory factory, Long orderId,
            Product product1, Product product2) {
        EntityManager manager1 = factory.createEntityManager();
        EntityManager manager2 = factory.createEntityManager();
        try {
            Order order1 = manager1.find(Order.class, orderId);
            Order order2 = manager2.find(Order.class, orderId);

            // Modify the order concurrently
            order1.add(product1, 1);
            order2.add(product2, 1);

            // Commit the concurrent changes
            manager1.getTransaction().begin();
            manager2.getTransaction().begin();
            manager1.getTransaction().commit();
            try {
                manager2.getTransaction().commit();
            } catch (RuntimeException e) {
                // Expected outcome when the cascaded version works: an optimistic-lock failure.
                System.err.println("second commit rejected: " + e);
            }
        } finally {
            try {
                rollbackIfActive(manager1);
                manager1.close();
            } finally {
                rollbackIfActive(manager2);
                manager2.close();
            }
        }
    }

    /** Loads the order in a fresh entity manager (its lines are fetched eagerly). */
    private static Order reload(EntityManagerFactory factory, Long orderId) {
        EntityManager manager3 = factory.createEntityManager();
        try {
            return manager3.find(Order.class, orderId);
        } finally {
            manager3.close();
        }
    }

    /** Rolls back the manager's transaction if a failure left it active. */
    private static void rollbackIfActive(EntityManager manager) {
        if (manager.getTransaction().isActive()) {
            manager.getTransaction().rollback();
        }
    }
}
Register the listener in persistence.xml:
Code:
<property name="hibernate.ejb.event.pre-update" value="CascadedVersionEventListener" />
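Add or remove the @CascadedVersion annotation on OrderLine.order to see the effect. With the annotation, the second concurrent commit is expected to fail on the order's version check, so the total stays below 150. Without it, both commits succeed and the total reaches 190.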
I hope this helps.
Xavier.