1 16 17 package org.springframework.test.jpa; 18 19 import java.lang.instrument.ClassFileTransformer ; 20 import java.lang.reflect.Constructor ; 21 import java.lang.reflect.Field ; 22 import java.lang.reflect.InvocationTargetException ; 23 import java.lang.reflect.Method ; 24 import java.util.HashMap ; 25 import java.util.Map ; 26 27 import javax.persistence.EntityManager; 28 import javax.persistence.EntityManagerFactory; 29 30 import junit.framework.TestCase; 31 32 import org.springframework.beans.BeanUtils; 33 import org.springframework.beans.BeansException; 34 import org.springframework.beans.factory.config.BeanPostProcessor; 35 import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessorAdapter; 36 import org.springframework.beans.factory.support.BeanDefinitionRegistry; 37 import org.springframework.beans.factory.support.DefaultListableBeanFactory; 38 import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; 39 import org.springframework.context.ConfigurableApplicationContext; 40 import org.springframework.context.support.GenericApplicationContext; 41 import org.springframework.instrument.classloading.LoadTimeWeaver; 42 import org.springframework.instrument.classloading.ResourceOverridingShadowingClassLoader; 43 import org.springframework.instrument.classloading.ShadowingClassLoader; 44 import org.springframework.orm.jpa.ExtendedEntityManagerCreator; 45 import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean; 46 import org.springframework.orm.jpa.SharedEntityManagerCreator; 47 import org.springframework.orm.jpa.persistenceunit.DefaultPersistenceUnitManager; 48 import org.springframework.test.annotation.AbstractAnnotationAwareTransactionalTests; 49 import org.springframework.util.StringUtils; 50 51 64 public abstract class AbstractJpaTests extends AbstractAnnotationAwareTransactionalTests { 65 66 private static final String DEFAULT_ORM_XML_LOCATION = "META-INF/orm.xml"; 67 68 73 private static Map <String , Object > contextCache = new HashMap <String , Object >(); 74 75 private static Map <String , ClassLoader > classLoaderCache = new HashMap <String , ClassLoader >(); 76 77 protected EntityManagerFactory entityManagerFactory; 78 79 88 private Object shadowParent; 89 90 94 protected EntityManager sharedEntityManager; 95 96 97 public void setEntityManagerFactory(EntityManagerFactory entityManagerFactory) { 98 this.entityManagerFactory = entityManagerFactory; 99 this.sharedEntityManager = SharedEntityManagerCreator.createSharedEntityManager(this.entityManagerFactory); 100 } 101 102 108 protected EntityManager createContainerManagedEntityManager() { 109 return ExtendedEntityManagerCreator.createContainerManagedEntityManager(this.entityManagerFactory); 110 } 111 112 119 protected boolean shouldUseShadowLoader() { 120 return true; 121 } 122 123 @Override 124 public void setDirty() { 125 super.setDirty(); 126 contextCache.remove(cacheKeys()); 127 classLoaderCache.remove(cacheKeys()); 128 129 if (this.shadowParent != null) { 134 try { 135 Method m = shadowParent.getClass().getMethod("setDirty", (Class []) null); 136 m.invoke(shadowParent, (Object []) null); 137 } 138 catch (Exception ex) { 139 throw new RuntimeException (ex); 140 } 141 } 142 } 143 144 145 @Override 146 public void runBare() throws Throwable { 147 if (!shouldUseShadowLoader()) { 148 super.runBare(); 149 return; 150 } 151 152 String combinationOfContextLocationsForThisTestClass = cacheKeys(); 153 ClassLoader classLoaderForThisTestClass = getClass().getClassLoader(); 154 ClassLoader initialClassLoader = Thread.currentThread().getContextClassLoader(); 156 157 if (this.shadowParent != null) { 158 Thread.currentThread().setContextClassLoader(classLoaderForThisTestClass); 159 super.runBare(); 160 } 161 else { 162 ShadowingClassLoader shadowingClassLoader = (ShadowingClassLoader) classLoaderCache.get(combinationOfContextLocationsForThisTestClass); 163 164 if (shadowingClassLoader == null) { 165 shadowingClassLoader = (ShadowingClassLoader) createShadowingClassLoader(classLoaderForThisTestClass); 166 classLoaderCache.put(combinationOfContextLocationsForThisTestClass, shadowingClassLoader); 167 } 168 try { 169 Thread.currentThread().setContextClassLoader(shadowingClassLoader); 170 String [] configLocations = getConfigLocations(); 171 172 Object cachedContext = contextCache.get(combinationOfContextLocationsForThisTestClass); 174 175 if (cachedContext == null) { 176 177 Class shadowingLoadTimeWeaverClass = shadowingClassLoader.loadClass(ShadowingLoadTimeWeaver.class.getName()); 179 Constructor constructor = shadowingLoadTimeWeaverClass.getConstructor(ClassLoader .class); 180 constructor.setAccessible(true); 181 Object ltw = constructor.newInstance(shadowingClassLoader); 182 183 Class beanFactoryClass = shadowingClassLoader.loadClass(DefaultListableBeanFactory.class.getName()); 185 Object beanFactory = BeanUtils.instantiateClass(beanFactoryClass); 186 187 Class beanDefinitionReaderClass = shadowingClassLoader.loadClass(XmlBeanDefinitionReader.class.getName()); 189 Class beanDefinitionRegistryClass = shadowingClassLoader.loadClass(BeanDefinitionRegistry.class.getName()); 190 Object reader = beanDefinitionReaderClass.getConstructor(beanDefinitionRegistryClass).newInstance(beanFactory); 191 192 Method loadBeanDefinitions = beanDefinitionReaderClass.getMethod("loadBeanDefinitions", String [].class); 194 loadBeanDefinitions.invoke(reader, new Object []{configLocations}); 195 196 Class loadTimeWeaverInjectingBeanPostProcessorClass = shadowingClassLoader.loadClass(LoadTimeWeaverInjectingBeanPostProcessor.class.getName()); 198 Class loadTimeWeaverClass = shadowingClassLoader.loadClass(LoadTimeWeaver.class.getName()); 199 Constructor bppConstructor = loadTimeWeaverInjectingBeanPostProcessorClass.getConstructor(loadTimeWeaverClass); 200 bppConstructor.setAccessible(true); 201 Object beanPostProcessor = bppConstructor.newInstance(ltw); 202 203 Class beanPostProcessorClass = shadowingClassLoader.loadClass(BeanPostProcessor.class.getName()); 205 Method addBeanPostProcessor = beanFactoryClass.getMethod("addBeanPostProcessor", beanPostProcessorClass); 206 addBeanPostProcessor.invoke(beanFactory, beanPostProcessor); 207 208 Class genericApplicationContextClass = shadowingClassLoader.loadClass(GenericApplicationContext.class.getName()); 210 Class defaultListableBeanFactoryClass = shadowingClassLoader.loadClass(DefaultListableBeanFactory.class.getName()); 211 cachedContext = genericApplicationContextClass.getConstructor(defaultListableBeanFactoryClass).newInstance(beanFactory); 212 213 genericApplicationContextClass.getMethod("refresh").invoke(cachedContext); 215 216 contextCache.put(combinationOfContextLocationsForThisTestClass, cachedContext); 218 } 219 Class shadowedTestClass = shadowingClassLoader.loadClass(getClass().getName()); 221 222 TestCase shadowedTestCase = (TestCase) BeanUtils.instantiateClass(shadowedTestClass); 225 226 227 Class thisShadowedClass = shadowingClassLoader.loadClass(AbstractJpaTests.class.getName()); 228 Field shadowed = thisShadowedClass.getDeclaredField("shadowParent"); 229 shadowed.setAccessible(true); 230 shadowed.set(shadowedTestCase, this); 231 232 233 Class applicationContextClass = shadowingClassLoader.loadClass(ConfigurableApplicationContext.class.getName()); 234 Method addContextMethod = shadowedTestClass.getMethod("addContext", Object .class, applicationContextClass); 235 addContextMethod.invoke(shadowedTestCase, configLocations, cachedContext); 236 237 shadowedTestCase.setName(getName()); 239 shadowedTestCase.runBare(); 240 } 241 catch (InvocationTargetException ex) { 242 throw ex.getTargetException(); 245 } 246 finally { 247 Thread.currentThread().setContextClassLoader(initialClassLoader); 248 } 249 } 250 } 251 252 protected String cacheKeys() { 253 return StringUtils.arrayToCommaDelimitedString(getConfigLocations()); 254 } 255 256 260 protected ClassLoader createShadowingClassLoader(ClassLoader classLoader) { 261 OrmXmlOverridingShadowingClassLoader orxl = new OrmXmlOverridingShadowingClassLoader(classLoader, 262 getActualOrmXmlLocation()); 263 customizeResourceOverridingShadowingClassLoader(orxl); 264 return orxl; 265 } 266 267 274 protected void customizeResourceOverridingShadowingClassLoader(ClassLoader shadowingClassLoader) { 275 } 277 278 283 protected String getActualOrmXmlLocation() { 284 return DEFAULT_ORM_XML_LOCATION; 285 } 286 287 288 private static class LoadTimeWeaverInjectingBeanPostProcessor extends InstantiationAwareBeanPostProcessorAdapter { 289 290 private final LoadTimeWeaver ltw; 291 292 public LoadTimeWeaverInjectingBeanPostProcessor(LoadTimeWeaver ltw) { 293 this.ltw = ltw; 294 } 295 296 public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { 297 if (bean instanceof LocalContainerEntityManagerFactoryBean) { 298 ((LocalContainerEntityManagerFactoryBean) bean).setLoadTimeWeaver(this.ltw); 299 } 300 if (bean instanceof DefaultPersistenceUnitManager) { 301 ((DefaultPersistenceUnitManager) bean).setLoadTimeWeaver(this.ltw); 302 } 303 return bean; 304 } 305 } 306 307 308 private static class ShadowingLoadTimeWeaver implements LoadTimeWeaver { 309 310 private final ClassLoader shadowingClassLoader; 311 312 private final Class shadowingClassLoaderClass; 313 314 public ShadowingLoadTimeWeaver(ClassLoader shadowingClassLoader) { 315 this.shadowingClassLoader = shadowingClassLoader; 316 this.shadowingClassLoaderClass = shadowingClassLoader.getClass(); 317 } 318 319 public ClassLoader getInstrumentableClassLoader() { 320 return (ClassLoader ) this.shadowingClassLoader; 321 } 322 323 public ClassLoader getThrowawayClassLoader() { 324 ResourceOverridingShadowingClassLoader roscl = new ResourceOverridingShadowingClassLoader(getClass().getClassLoader()); 329 if (shadowingClassLoader instanceof ResourceOverridingShadowingClassLoader) { 330 roscl.copyOverrides((ResourceOverridingShadowingClassLoader) shadowingClassLoader); 331 } 332 if (shadowingClassLoader instanceof ShadowingClassLoader) { 333 roscl.copyTransformers((ShadowingClassLoader) shadowingClassLoader); 334 } 335 return roscl; 336 } 337 338 public void addTransformer(ClassFileTransformer transformer) { 339 try { 340 Method addClassFileTransformer = 341 this.shadowingClassLoaderClass.getMethod("addTransformer", ClassFileTransformer .class); 342 addClassFileTransformer.setAccessible(true); 343 addClassFileTransformer.invoke(this.shadowingClassLoader, transformer); 344 } 345 catch (Exception ex) { 346 throw new RuntimeException (ex); 347 } 348 } 349 } 350 351 } 352 | Popular Tags |