001: package org.hibernate.bytecode.javassist;
002:
003: import java.io.DataInputStream;
004: import java.io.DataOutputStream;
005: import java.io.File;
006: import java.io.FileInputStream;
007: import java.io.FileOutputStream;
008: import java.util.HashMap;
009: import java.util.Iterator;
010: import java.util.List;
011:
012: import javassist.CannotCompileException;
013: import javassist.bytecode.AccessFlag;
014: import javassist.bytecode.BadBytecode;
015: import javassist.bytecode.Bytecode;
016: import javassist.bytecode.ClassFile;
017: import javassist.bytecode.CodeAttribute;
018: import javassist.bytecode.CodeIterator;
019: import javassist.bytecode.ConstPool;
020: import javassist.bytecode.Descriptor;
021: import javassist.bytecode.FieldInfo;
022: import javassist.bytecode.MethodInfo;
023: import javassist.bytecode.Opcode;
024: import org.hibernate.bytecode.javassist.FieldFilter;
025: import org.hibernate.bytecode.javassist.FieldHandled;
026: import org.hibernate.bytecode.javassist.FieldHandler;
027:
/**
 * Enhances a class file so that reads and writes of its instance fields can be
 * intercepted by a {@link FieldHandler}.
 * <p>
 * Transforming a non-interface class:
 * <ol>
 * <li>adds a private transient field {@code $JAVASSIST_READ_WRITE_HANDLER} that holds the
 * handler, public {@code getFieldHandler()} / {@code setFieldHandler(FieldHandler)}
 * accessors for it, and {@link FieldHandled} to the implemented interfaces;</li>
 * <li>for each non-static field accepted by the {@link FieldFilter}, adds a public
 * {@code $javassist_read_<field>()} and/or {@code $javassist_write_<field>(value)} method
 * that calls the handler when one is set and otherwise accesses the field directly;</li>
 * <li>rewrites the {@code getfield}/{@code putfield} instructions on those fields in the
 * class's other methods into {@code invokevirtual} calls of the generated methods.</li>
 * </ol>
 * Interfaces are left untouched. Instances keep per-class state in unsynchronized maps,
 * so one instance must not be used by several threads at the same time.
 *
 * @author Muga Nishizawa
 */
public class FieldTransformer {

    private static final String EACH_READ_METHOD_PREFIX = "$javassist_read_";

    private static final String EACH_WRITE_METHOD_PREFIX = "$javassist_write_";

    private static final String FIELD_HANDLED_TYPE_NAME = FieldHandled.class.getName();

    private static final String HANDLER_FIELD_NAME = "$JAVASSIST_READ_WRITE_HANDLER";

    private static final String FIELD_HANDLER_TYPE_NAME = FieldHandler.class.getName();

    private static final String HANDLER_FIELD_DESCRIPTOR =
            'L' + FIELD_HANDLER_TYPE_NAME.replace('.', '/') + ';';

    private static final String GETFIELDHANDLER_METHOD_NAME = "getFieldHandler";

    private static final String SETFIELDHANDLER_METHOD_NAME = "setFieldHandler";

    private static final String GETFIELDHANDLER_METHOD_DESCRIPTOR =
            "()" + HANDLER_FIELD_DESCRIPTOR;

    private static final String SETFIELDHANDLER_METHOD_DESCRIPTOR =
            "(" + HANDLER_FIELD_DESCRIPTOR + ")V";

    /** Decides, per field, whether reads and/or writes are intercepted. */
    private FieldFilter filter;

    /** Field name -> field descriptor of the intercepted-read fields of the current class. */
    private final HashMap readableFields;

    /** Field name -> field descriptor of the intercepted-write fields of the current class. */
    private final HashMap writableFields;

    /**
     * Creates a transformer without a filter. A filter must be supplied through
     * {@link #setFieldFilter(FieldFilter)} before transforming a class that declares
     * instance fields; otherwise transformation fails with a NullPointerException.
     */
    public FieldTransformer() {
        this(null);
    }

    /**
     * Creates a transformer that consults {@code f} for every non-static field.
     *
     * @param f the field filter (see the no-arg constructor regarding {@code null})
     */
    public FieldTransformer(FieldFilter f) {
        filter = f;
        readableFields = new HashMap();
        writableFields = new HashMap();
    }

    /**
     * Replaces the field filter used by subsequent transformations.
     *
     * @param f the new filter
     */
    public void setFieldFilter(FieldFilter f) {
        filter = f;
    }

    /**
     * Reads the class file stored in {@code file}, transforms it and writes the result back
     * to the same file.
     *
     * @param file the class file to rewrite in place
     * @throws Exception if the file cannot be read, parsed, transformed or written
     */
    public void transform(File file) throws Exception {
        ClassFile classfile;
        DataInputStream in = new DataInputStream(new FileInputStream(file));
        try {
            classfile = new ClassFile(in);
        } finally {
            // Close before the same file is reopened for writing, and never leak the
            // handle when parsing fails.
            in.close();
        }
        transform(classfile);
        DataOutputStream out = new DataOutputStream(new FileOutputStream(file));
        try {
            classfile.write(out);
        } finally {
            out.close();
        }
    }

    /**
     * Transforms {@code classfile} in memory as described in the class comment.
     * Interfaces are returned unchanged.
     *
     * @param classfile the class to enhance; modified in place
     * @throws Exception declared for compatibility; Javassist compile failures are
     *         rethrown as RuntimeException with the original exception as cause
     */
    public void transform(ClassFile classfile) throws Exception {
        if (classfile.isInterface()) {
            return;
        }
        // The recorded fields describe the class currently being transformed; drop any
        // left over from an earlier call so they cannot redirect accesses in this class
        // to read/write methods it never received.
        readableFields.clear();
        writableFields.clear();
        try {
            addFieldHandlerField(classfile);
            addGetFieldHandlerMethod(classfile);
            addSetFieldHandlerMethod(classfile);
            addFieldHandledInterface(classfile);
            addReadWriteMethods(classfile);
            redirectFieldAccesses(classfile);
        } catch (CannotCompileException e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /** Adds the private transient field that stores the FieldHandler. */
    private void addFieldHandlerField(ClassFile classfile) throws CannotCompileException {
        ConstPool cp = classfile.getConstPool();
        FieldInfo finfo = new FieldInfo(cp, HANDLER_FIELD_NAME, HANDLER_FIELD_DESCRIPTOR);
        finfo.setAccessFlags(AccessFlag.PRIVATE | AccessFlag.TRANSIENT);
        classfile.addField(finfo);
    }

    /** Adds {@code public FieldHandler getFieldHandler()}, returning the handler field. */
    private void addGetFieldHandlerMethod(ClassFile classfile) throws CannotCompileException {
        ConstPool cp = classfile.getConstPool();
        int thisClassIndex = cp.getThisClassInfo();
        MethodInfo minfo = new MethodInfo(cp, GETFIELDHANDLER_METHOD_NAME,
                GETFIELDHANDLER_METHOD_DESCRIPTOR);
        // locals: | this |
        Bytecode code = new Bytecode(cp, 2, 1);
        // aload_0                      load this
        code.addAload(0);
        // getfield                     the handler field added by addFieldHandlerField
        code.addOpcode(Opcode.GETFIELD);
        int fieldIndex = cp.addFieldrefInfo(thisClassIndex, HANDLER_FIELD_NAME,
                HANDLER_FIELD_DESCRIPTOR);
        code.addIndex(fieldIndex);
        // areturn                      return the handler (may be null)
        code.addOpcode(Opcode.ARETURN);
        minfo.setCodeAttribute(code.toCodeAttribute());
        minfo.setAccessFlags(AccessFlag.PUBLIC);
        classfile.addMethod(minfo);
    }

    /** Adds {@code public void setFieldHandler(FieldHandler)}, assigning the handler field. */
    private void addSetFieldHandlerMethod(ClassFile classfile) throws CannotCompileException {
        ConstPool cp = classfile.getConstPool();
        int thisClassIndex = cp.getThisClassInfo();
        MethodInfo minfo = new MethodInfo(cp, SETFIELDHANDLER_METHOD_NAME,
                SETFIELDHANDLER_METHOD_DESCRIPTOR);
        // locals: | this | handler |
        Bytecode code = new Bytecode(cp, 3, 3);
        // aload_0                      load this
        code.addAload(0);
        // aload_1                      load the new handler
        code.addAload(1);
        // putfield                     the handler field added by addFieldHandlerField
        code.addOpcode(Opcode.PUTFIELD);
        int fieldIndex = cp.addFieldrefInfo(thisClassIndex, HANDLER_FIELD_NAME,
                HANDLER_FIELD_DESCRIPTOR);
        code.addIndex(fieldIndex);
        // return
        code.addOpcode(Opcode.RETURN);
        minfo.setCodeAttribute(code.toCodeAttribute());
        minfo.setAccessFlags(AccessFlag.PUBLIC);
        classfile.addMethod(minfo);
    }

    /** Appends FieldHandled to the interfaces the class implements. */
    private void addFieldHandledInterface(ClassFile classfile) {
        String[] interfaceNames = classfile.getInterfaces();
        String[] newInterfaceNames = new String[interfaceNames.length + 1];
        System.arraycopy(interfaceNames, 0, newInterfaceNames, 0, interfaceNames.length);
        newInterfaceNames[newInterfaceNames.length - 1] = FIELD_HANDLED_TYPE_NAME;
        classfile.setInterfaces(newInterfaceNames);
    }

    /**
     * Adds read and/or write methods for every non-static field (other than the handler
     * field) that the filter accepts, recording each such field's descriptor by name.
     */
    private void addReadWriteMethods(ClassFile classfile) throws CannotCompileException {
        List fields = classfile.getFields();
        for (Iterator it = fields.iterator(); it.hasNext();) {
            FieldInfo finfo = (FieldInfo) it.next();
            boolean isStatic = (finfo.getAccessFlags() & AccessFlag.STATIC) != 0;
            if (isStatic || finfo.getName().equals(HANDLER_FIELD_NAME)) {
                continue;
            }
            String name = finfo.getName();
            String descriptor = finfo.getDescriptor();
            if (filter.handleRead(descriptor, name)) {
                addReadMethod(classfile, finfo);
                readableFields.put(name, descriptor);
            }
            if (filter.handleWrite(descriptor, name)) {
                addWriteMethod(classfile, finfo);
                writableFields.put(name, descriptor);
            }
        }
    }

    /**
     * Adds {@code public T $javassist_read_<field>()} for {@code finfo}. Generated code
     * (T = field type; xstore/xload/xreturn = T's typed opcodes):
     * <pre>
     *   aload_0; getfield field          push the current value
     *   aload_0; invokeinterface FieldHandled.getFieldHandler()
     *   ifnonnull +4                     handler set: jump to xstore_1
     *   xreturn                          no handler: return the raw value
     *   xstore_1                         keep the raw value in local 1
     *   aload_0; invokeinterface FieldHandled.getFieldHandler()
     *   aload_0; ldc fieldName; xload_1
     *   invokeinterface FieldHandler.readT (+ checkcast for reference types)
     *   xreturn
     * </pre>
     */
    private void addReadMethod(ClassFile classfile, FieldInfo finfo)
            throws CannotCompileException {
        ConstPool cp = classfile.getConstPool();
        int thisClassIndex = cp.getThisClassInfo();
        String fieldName = finfo.getName();
        String fieldDesc = finfo.getDescriptor();
        MethodInfo minfo = new MethodInfo(cp, EACH_READ_METHOD_PREFIX + fieldName,
                "()" + fieldDesc);
        // locals: | this | raw value (1 or 2 slots) |
        Bytecode code = new Bytecode(cp, 5, 3);
        code.addAload(0);
        code.addOpcode(Opcode.GETFIELD);
        int fieldIndex = cp.addFieldrefInfo(thisClassIndex, fieldName, fieldDesc);
        code.addIndex(fieldIndex);
        code.addAload(0);
        int fieldHandledIndex = cp.addClassInfo(FIELD_HANDLED_TYPE_NAME);
        code.addInvokeinterface(fieldHandledIndex, GETFIELDHANDLER_METHOD_NAME,
                GETFIELDHANDLER_METHOD_DESCRIPTOR, 1);
        // Offset 4 = 3 bytes of ifnonnull + 1 byte of the xreturn below.
        code.addOpcode(Opcode.IFNONNULL);
        code.addIndex(4);
        addTypeDependDataReturn(code, fieldDesc);
        addTypeDependDataStore(code, fieldDesc, 1);
        code.addAload(0);
        code.addInvokeinterface(fieldHandledIndex, GETFIELDHANDLER_METHOD_NAME,
                GETFIELDHANDLER_METHOD_DESCRIPTOR, 1);
        code.addAload(0);
        code.addLdc(fieldName);
        addTypeDependDataLoad(code, fieldDesc, 1);
        addInvokeFieldHandlerMethod(classfile, code, fieldDesc, true);
        addTypeDependDataReturn(code, fieldDesc);

        minfo.setCodeAttribute(code.toCodeAttribute());
        minfo.setAccessFlags(AccessFlag.PUBLIC);
        classfile.addMethod(minfo);
    }

    /**
     * Adds {@code public void $javassist_write_<field>(T value)} for {@code finfo}.
     * Generated code (T = field type; xload = T's typed load opcode):
     * <pre>
     *   aload_0; invokeinterface FieldHandled.getFieldHandler()
     *   ifnonnull +9                     handler set: jump past the fast path
     *   aload_0; xload_1; putfield field; return     no handler: plain store
     *   aload_0; dup; invokeinterface FieldHandled.getFieldHandler()
     *   aload_0; ldc fieldName
     *   aload_0; getfield field          old value
     *   xload_1                          new value
     *   invokeinterface FieldHandler.writeT (+ checkcast for reference types)
     *   putfield field                   store what the handler returned
     *   return
     * </pre>
     */
    private void addWriteMethod(ClassFile classfile, FieldInfo finfo)
            throws CannotCompileException {
        ConstPool cp = classfile.getConstPool();
        int thisClassIndex = cp.getThisClassInfo();
        String fieldName = finfo.getName();
        String fieldDesc = finfo.getDescriptor();
        MethodInfo minfo = new MethodInfo(cp, EACH_WRITE_METHOD_PREFIX + fieldName,
                "(" + fieldDesc + ")V");
        // locals: | this | new value (1 or 2 slots) |
        Bytecode code = new Bytecode(cp, 6, 3);
        code.addAload(0);
        int fieldHandledIndex = cp.addClassInfo(FIELD_HANDLED_TYPE_NAME);
        code.addInvokeinterface(fieldHandledIndex, GETFIELDHANDLER_METHOD_NAME,
                GETFIELDHANDLER_METHOD_DESCRIPTOR, 1);
        // Offset 9 = ifnonnull(3) + aload_0(1) + xload_1(1) + putfield(3) + return(1).
        code.addOpcode(Opcode.IFNONNULL);
        code.addIndex(9);
        code.addAload(0);
        addTypeDependDataLoad(code, fieldDesc, 1);
        code.addOpcode(Opcode.PUTFIELD);
        int fieldIndex = cp.addFieldrefInfo(thisClassIndex, fieldName, fieldDesc);
        code.addIndex(fieldIndex);
        code.growStack(-Descriptor.dataSize(fieldDesc));
        code.addOpcode(Opcode.RETURN);
        code.addAload(0);
        code.addOpcode(Opcode.DUP);
        code.addInvokeinterface(fieldHandledIndex, GETFIELDHANDLER_METHOD_NAME,
                GETFIELDHANDLER_METHOD_DESCRIPTOR, 1);
        code.addAload(0);
        code.addLdc(fieldName);
        code.addAload(0);
        code.addOpcode(Opcode.GETFIELD);
        code.addIndex(fieldIndex);
        // getfield pops the object reference and pushes the field value
        code.growStack(Descriptor.dataSize(fieldDesc) - 1);
        addTypeDependDataLoad(code, fieldDesc, 1);
        addInvokeFieldHandlerMethod(classfile, code, fieldDesc, false);
        code.addOpcode(Opcode.PUTFIELD);
        code.addIndex(fieldIndex);
        code.growStack(-Descriptor.dataSize(fieldDesc));
        code.addOpcode(Opcode.RETURN);

        minfo.setCodeAttribute(code.toCodeAttribute());
        minfo.setAccessFlags(AccessFlag.PUBLIC);
        classfile.addMethod(minfo);
    }

    /**
     * Rewrites intercepted getfield/putfield instructions in every method except the
     * generated read/write methods and the handler accessors, which must keep touching
     * the fields directly.
     */
    private void redirectFieldAccesses(ClassFile classfile) throws CannotCompileException {
        List methods = classfile.getMethods();
        for (Iterator it = methods.iterator(); it.hasNext();) {
            MethodInfo minfo = (MethodInfo) it.next();
            String methodName = minfo.getName();
            if (methodName.startsWith(EACH_READ_METHOD_PREFIX)
                    || methodName.startsWith(EACH_WRITE_METHOD_PREFIX)
                    || methodName.equals(GETFIELDHANDLER_METHOD_NAME)
                    || methodName.equals(SETFIELDHANDLER_METHOD_NAME)) {
                continue;
            }
            CodeAttribute codeAttr = minfo.getCodeAttribute();
            if (codeAttr == null) {
                // No bytecode (abstract/native method): skip it but keep processing the
                // remaining methods; stopping here would leave them untransformed.
                continue;
            }
            CodeIterator iter = codeAttr.iterator();
            while (iter.hasNext()) {
                try {
                    int pos = iter.next();
                    replaceGetfieldWithReadCall(classfile, iter, pos);
                    // A rewritten getfield is now invokevirtual, so this is a no-op then.
                    replacePutfieldWithWriteCall(classfile, iter, pos);
                } catch (BadBytecode e) {
                    throw new CannotCompileException(e);
                }
            }
        }
    }

    /**
     * If the instruction at {@code pos} is a getfield of an intercepted-read field of this
     * class, overwrites it in place (same 3-byte length) with an invokevirtual of the
     * field's generated read method.
     */
    private void replaceGetfieldWithReadCall(ClassFile classfile, CodeIterator iter, int pos) {
        if (iter.byteAt(pos) != Opcode.GETFIELD) {
            return;
        }
        ConstPool cp = classfile.getConstPool();
        int fieldrefIndex = iter.u16bitAt(pos + 1);
        String fieldName = cp.getFieldrefName(fieldrefIndex);
        String ownerName = cp.getFieldrefClassName(fieldrefIndex);
        if (!classfile.getName().equals(ownerName) || !readableFields.containsKey(fieldName)) {
            return;
        }
        String desc = "()" + (String) readableFields.get(fieldName);
        int readMethodIndex = cp.addMethodrefInfo(cp.getThisClassInfo(),
                EACH_READ_METHOD_PREFIX + fieldName, desc);
        iter.writeByte(Opcode.INVOKEVIRTUAL, pos);
        iter.write16bit(readMethodIndex, pos + 1);
    }

    /**
     * If the instruction at {@code pos} is a putfield of an intercepted-write field of this
     * class, overwrites it in place (same 3-byte length) with an invokevirtual of the
     * field's generated write method.
     */
    private void replacePutfieldWithWriteCall(ClassFile classfile, CodeIterator iter, int pos) {
        if (iter.byteAt(pos) != Opcode.PUTFIELD) {
            return;
        }
        ConstPool cp = classfile.getConstPool();
        int fieldrefIndex = iter.u16bitAt(pos + 1);
        String fieldName = cp.getFieldrefName(fieldrefIndex);
        String ownerName = cp.getFieldrefClassName(fieldrefIndex);
        if (!classfile.getName().equals(ownerName) || !writableFields.containsKey(fieldName)) {
            return;
        }
        String desc = "(" + (String) writableFields.get(fieldName) + ")V";
        int writeMethodIndex = cp.addMethodrefInfo(cp.getThisClassInfo(),
                EACH_WRITE_METHOD_PREFIX + fieldName, desc);
        iter.writeByte(Opcode.INVOKEVIRTUAL, pos);
        iter.write16bit(writeMethodIndex, pos + 1);
    }

    /**
     * Emits the invokeinterface of the FieldHandler read*/write* method matching
     * {@code typeName}. Reference types use readObject/writeObject followed by a checkcast
     * to the field type.
     *
     * @param typeName     the field descriptor
     * @param isReadMethod true for read*, false for write*
     * @throws RuntimeException ("bad type: ...") for an unsupported descriptor
     */
    private static void addInvokeFieldHandlerMethod(ClassFile classfile, Bytecode code,
            String typeName, boolean isReadMethod) {
        ConstPool cp = classfile.getConstPool();
        int handlerClassIndex = cp.addClassInfo(FIELD_HANDLER_TYPE_NAME);
        if (isReferenceType(typeName)) {
            // Lpkg/Name; -> pkg.Name; arrays keep descriptor form with dots as separators
            String castType;
            if (typeName.charAt(0) == 'L') {
                castType = typeName.substring(1, typeName.length() - 1).replace('/', '.');
            } else {
                castType = typeName.replace('/', '.');
            }
            if (isReadMethod) {
                code.addInvokeinterface(handlerClassIndex, "readObject",
                        "(Ljava/lang/Object;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object;",
                        4);
            } else {
                code.addInvokeinterface(handlerClassIndex, "writeObject",
                        "(Ljava/lang/Object;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;",
                        5);
            }
            code.addCheckcast(castType);
            return;
        }

        char typeCode = primitiveTypeCode(typeName);
        String suffix;
        switch (typeCode) {
            case 'Z': suffix = "Boolean"; break;
            case 'B': suffix = "Byte"; break;
            case 'C': suffix = "Char"; break;
            case 'I': suffix = "Int"; break;
            case 'S': suffix = "Short"; break;
            case 'D': suffix = "Double"; break;
            case 'F': suffix = "Float"; break;
            default: suffix = "Long"; break; // 'J': primitiveTypeCode admits nothing else
        }
        // invokeinterface count = argument slots incl. receiver: handler + owner object +
        // field name + one (read) or two (write) values; long and double take two slots.
        int valueSlots = (typeCode == 'D' || typeCode == 'J') ? 2 : 1;
        String argPrefix = "(Ljava/lang/Object;Ljava/lang/String;";
        if (isReadMethod) {
            code.addInvokeinterface(handlerClassIndex, "read" + suffix,
                    argPrefix + typeName + ")" + typeName, 3 + valueSlots);
        } else {
            code.addInvokeinterface(handlerClassIndex, "write" + suffix,
                    argPrefix + typeName + typeName + ")" + typeName, 3 + 2 * valueSlots);
        }
    }

    /** Emits the typed load of local variable {@code i} for the field type {@code typeName}. */
    private static void addTypeDependDataLoad(Bytecode code, String typeName, int i) {
        if (isReferenceType(typeName)) {
            code.addAload(i);
            return;
        }
        switch (primitiveTypeCode(typeName)) {
            case 'D': code.addDload(i); break;
            case 'F': code.addFload(i); break;
            case 'J': code.addLload(i); break;
            default: code.addIload(i); break; // Z, B, C, I, S
        }
    }

    /** Emits the typed store into local variable {@code i} for the field type {@code typeName}. */
    private static void addTypeDependDataStore(Bytecode code, String typeName, int i) {
        if (isReferenceType(typeName)) {
            code.addAstore(i);
            return;
        }
        switch (primitiveTypeCode(typeName)) {
            case 'D': code.addDstore(i); break;
            case 'F': code.addFstore(i); break;
            case 'J': code.addLstore(i); break;
            default: code.addIstore(i); break; // Z, B, C, I, S
        }
    }

    /** Emits the typed return instruction for the field type {@code typeName}. */
    private static void addTypeDependDataReturn(Bytecode code, String typeName) {
        if (isReferenceType(typeName)) {
            code.addOpcode(Opcode.ARETURN);
            return;
        }
        switch (primitiveTypeCode(typeName)) {
            case 'D': code.addOpcode(Opcode.DRETURN); break;
            case 'F': code.addOpcode(Opcode.FRETURN); break;
            case 'J': code.addOpcode(Opcode.LRETURN); break;
            default: code.addOpcode(Opcode.IRETURN); break; // Z, B, C, I, S
        }
    }

    /**
     * Returns whether {@code descriptor} is a reference type: an object type
     * ({@code Lpkg/Name;}) or any array type ({@code [...}).
     */
    private static boolean isReferenceType(String descriptor) {
        char first = descriptor.charAt(0);
        return (first == 'L' && descriptor.charAt(descriptor.length() - 1) == ';')
                || first == '[';
    }

    /**
     * Returns the code of a primitive field descriptor (one of {@code Z B C I S D F J}).
     *
     * @throws RuntimeException ("bad type: ...") for any other descriptor
     */
    private static char primitiveTypeCode(String descriptor) {
        if (descriptor.length() == 1 && "ZBCISDFJ".indexOf(descriptor.charAt(0)) >= 0) {
            return descriptor.charAt(0);
        }
        throw new RuntimeException("bad type: " + descriptor);
    }

}
|