Here is my solution so far. If someone with deeper knowledge of Hibernate's internals could review it and make it safe, that would be great!
Code:
import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
/**
 * Marker for an association (declared on a field or a getter) whose target entity
 * must be treated as modified whenever the annotated entity is updated.
 * <p>
 * The annotation has no attributes. It is retained at {@code RUNTIME} so that it can
 * be discovered reflectively when the mapping is inspected.
 */
@Retention(RUNTIME)
@Target({METHOD, FIELD})
public @interface CascadedVersion {
    // intentionally empty: pure marker annotation
}
Code:
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import org.hibernate.EntityMode;
import org.hibernate.LockMode;
import org.hibernate.annotations.common.reflection.ReflectionManager;
import org.hibernate.annotations.common.reflection.XClass;
import org.hibernate.annotations.common.reflection.XProperty;
import org.hibernate.annotations.common.reflection.java.JavaReflectionManager;
import org.hibernate.cfg.Configuration;
import org.hibernate.engine.EntityEntry;
import org.hibernate.engine.PersistenceContext;
import org.hibernate.event.EventSource;
import org.hibernate.event.Initializable;
import org.hibernate.event.PreUpdateEvent;
import org.hibernate.event.PreUpdateEventListener;
import org.hibernate.mapping.PersistentClass;
import org.hibernate.mapping.Property;
import org.hibernate.persister.entity.EntityPersister;
public class CascadedVersionEventListener implements PreUpdateEventListener, Initializable {
    private static final ReflectionManager manager = new JavaReflectionManager();
    private static final long serialVersionUID = 1L;
    private final Map<String, Set<String>> cascadeProperties = new HashMap<String, Set<String>>();
    public void initialize(Configuration configuration) {
        Iterator<?> persistentClassIterator = configuration.getClassMappings();
        while (persistentClassIterator.hasNext()) {
            PersistentClass persistentClass = PersistentClass.class.cast(persistentClassIterator.next());
            Set<String> properties = new HashSet<String>();
            Iterator<?> propertyIterator = persistentClass.getReferenceablePropertyIterator();
            while (propertyIterator.hasNext()) {
                Property property = Property.class.cast(propertyIterator.next());
                if (isCascadedVersion(property)) {
                    properties.add(property.getName());
                }
            }
            this.cascadeProperties.put(persistentClass.getEntityName(), properties);
        }
    }
    public boolean onPreUpdate(PreUpdateEvent event) {
        Set<Object> checkedEntities = new HashSet<Object>();
        Queue<Object> entitiesToCheck = new LinkedList<Object>();
        entitiesToCheck.add(event.getEntity());
        while (!entitiesToCheck.isEmpty()) {
            Object entity = entitiesToCheck.poll();
            String entityName = entity.getClass().getName();
            if (checkedEntities.add(entity)) {
                EventSource source = event.getSession();
                PersistenceContext context = source.getPersistenceContext();
                EntityPersister persister = source.getEntityPersister(entityName, entity);
                EntityMode mode = persister.guessEntityMode(entity);
                EntityEntry entry = context.getEntry(entity);
                Object[] oldState = entry.getLoadedState();
                Object[] newState = persister.getPropertyValues(entity, mode);
                boolean dirty = persister.findDirty(oldState, newState, entity, event.getSession()) != null;
                boolean versionable = persister.getVersionProperty() >= 0;
                if (versionable && !dirty) {
                    Serializable id = persister.getIdentifier(entity, mode);
                    if (entity != event.getEntity()) {
                        PreUpdateEvent event2 = new PreUpdateEvent(entity, id, newState, oldState, persister, source);
                        // force hibernate validator here...
                        for (PreUpdateEventListener listener : source.getListeners().getPreUpdateEventListeners()) {
                            listener.onPreUpdate(event2);
                        }
                        source.lock(entity, LockMode.FORCE);
                    }
                }
                for (String propertyName : this.cascadeProperties.get(entityName)) {
                    entitiesToCheck.offer(persister.getPropertyValue(entity, propertyName, mode));
                }
            }
        }
        return false;
    }
    private boolean isCascadedVersion(Property property) {
        Class<?> mappedClass = property.getPersistentClass().getMappedClass();
        XClass xClass = manager.toXClass(mappedClass);
        for (XProperty xProperty : xClass.getDeclaredProperties(property.getPropertyAccessorName())) {
            if (xProperty.getName().equals(property.getName()) && xProperty.isAnnotationPresent(CascadedVersion.class)) {
                return true;
            }
        }
        return false;
    }
}
Test code:
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import javax.persistence.Entity;
import javax.persistence.FetchType;
import javax.persistence.GeneratedValue;
import javax.persistence.Id;
import javax.persistence.OneToMany;
import javax.persistence.Table;
import javax.persistence.Version;
import org.hibernate.annotations.Cascade;
import org.hibernate.annotations.CascadeType;
import org.hibernate.validator.AssertTrue;
@Entity
@Table(name = "ORDERS")
public class Order {
    @Id
    @GeneratedValue
    Long id;
    @Version
    Integer version;
    @OneToMany(mappedBy = "order", fetch = FetchType.EAGER)
    @Cascade( {CascadeType.ALL, CascadeType.DELETE_ORPHAN})
    private Set<OrderLine> lines = new HashSet<OrderLine>();
    public void add(Product product, int quantity) {
        Iterator<OrderLine> iterator = this.lines.iterator();
        while (iterator.hasNext()) {
            OrderLine line = iterator.next();
            if (line.getProduct().equals(product)) {
                line.addQuantity(quantity);
                if (line.getQuantity() <= 0) {
                    iterator.remove();
                }
                return;
            }
        }
        this.lines.add(new OrderLine(this, product, quantity));
    }
    public Set<OrderLine> getLines() {
        return Collections.unmodifiableSet(this.lines);
    }
    public double getTotal() {
        double total = 0;
        for (OrderLine line : this.lines) {
            total += line.getSubTotal();
        }
        return total;
    }
    @AssertTrue
    boolean checkTotal() {
        return getTotal() < 150;
    }
}
Code:
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.Id;
import javax.persistence.JoinColumn;
import javax.persistence.ManyToOne;
import javax.persistence.Table;
import org.hibernate.annotations.NaturalId;
@Entity
@Table(name = "ORDER_LINES")
public class OrderLine {
    @Id
    @GeneratedValue
    Long id;
    @ManyToOne
    @JoinColumn
    @NaturalId
    @CascadedVersion
    Order order;
    @ManyToOne
    @JoinColumn
    @NaturalId
    private Product product;
    private int quantity;
    protected OrderLine(Order order, Product product, int quantity) {
        this.order = order;
        this.quantity = quantity;
        this.product = product;
    }
    public boolean equals(Object that) {
        return super.equals(that) || getClass().isInstance(that) && equals(getClass().cast(that));
    }
    public Product getProduct() {
        return this.product;
    }
    public int getQuantity() {
        return this.quantity;
    }
    public double getSubTotal() {
        return this.quantity * this.product.getPrice();
    }
    public int hashCode() {
        return this.order.hashCode() ^ this.product.hashCode();
    }
    void addQuantity(int quantity) {
        this.quantity += quantity;
    }
    private boolean equals(OrderLine that) {
        return this.order.equals(that.order) && this.product.equals(that.product);
    }
}
Code:
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.Id;
import javax.persistence.Table;
import javax.persistence.Version;
import org.hibernate.annotations.NaturalId;
@Entity
@Table(name = "PRODUCTS")
public class Product {
    @Id
    @GeneratedValue
    Long id;
    @Version
    Integer version;
    @NaturalId
    private String name;
    private double price;
    public Product(String name, double price) {
        this.name = name;
        this.price = price;
    }
    public boolean equals(Object that) {
        return super.equals(that) || getClass().isInstance(that) && equals(getClass().cast(that));
    }
    public String getName() {
        return this.name;
    }
    public double getPrice() {
        return this.price;
    }
    public int hashCode() {
        return this.name.hashCode();
    }
    private boolean equals(Product that) {
        return this.name.equals(that.name);
    }
}
Code:
import javax.persistence.EntityManager;
import javax.persistence.EntityManagerFactory;
import javax.persistence.Persistence;
public class Main {
    public static void main(String[] args) {
        EntityManagerFactory factory = Persistence.createEntityManagerFactory("default");
        // Create initial data: 2 products + 1 order
        EntityManager manager = factory.createEntityManager();
        Product product1 = new Product("hibernate in action", 45f);
        Product product2 = new Product("java persistence with hibernate", 50f);
        Order order = new Order();
        order.add(product1, 1);
        order.add(product2, 1);
        manager.getTransaction().begin();
        manager.persist(product1);
        manager.persist(product2);
        manager.persist(order);
        manager.getTransaction().commit();
        manager.close();
        // Reload the order 2x within different managers
        EntityManager manager1 = factory.createEntityManager();
        EntityManager manager2 = factory.createEntityManager();
        Order order1 = manager1.find(Order.class, 1L);
        Order order2 = manager2.find(Order.class, 1L);
        // Modify the order concurrently
        order1.add(product1, 1);
        order2.add(product2, 1);
        // Commit the concurrent changes
        manager1.getTransaction().begin();
        manager2.getTransaction().begin();
        manager1.getTransaction().commit();
        try {
            manager2.getTransaction().commit();
        } catch (RuntimeException e) {
            System.err.println(e.getMessage());
        }
        manager1.close();
        manager2.close();
        // Reload the order one last time within a new manager
        EntityManager manager3 = factory.createEntityManager();
        Order order3 = manager3.find(Order.class, 1L);
        manager3.close();
        // Assert the total is < 150
        System.out.println("total price: " + order3.getTotal());
        System.out.println("invariant status: " + (order3.getTotal() < 150 ? "OK" : "KO"));
    }
}
In persistence.xml, register the listener:
Code:
<property name="hibernate.ejb.event.pre-update" value="CascadedVersionEventListener" />
Add or remove the @CascadedVersion annotation on OrderLine.order to see the effect.
I hope this helps.
Xavier.