我有一个tensorflow op,它可以在非XLA情况下使用,但不能在XLA中使用。
如果op在用修饰的函数中使用@tf.function(experimental_compile=True)
,则tensorflow会给出错误,因为op没有使用XLA的实现。
我要实现的解决方案是以某种方式在op实现内部检测用户是否指定了@tf.function(experimental_compile=True)
。
因此,如果用户未指定此选项,那么我将使用自定义GPU非XLA实现。另一方面,如果用户指定了此选项,则我将仅使用默认实现,然后让XLA本身在其上实现操作融合优化。
那么,有谁知道是否有可能在op实现内部检测op是否已在@tf.function(experimental_compile=True)
装饰器的函数内部使用?
是的。
如果检查tensorflow源代码的tf.function
实现,则可以看到它是一个返回Function
对象的包装器。_experimental_compile
装饰该对象时,该对象将暴露给您的方法。因此,您可以在方法定义内检查您的方法是否具有此属性,然后编写有关是否执行xla敏感操作的条件语句。
例子:
# relevant packages
import tensorflow as tf
from tensorflow.python.platform import test as test_lib
# a fake op that doesn't have an xla implementation
def xla_sensitive_op(tensor):
"""An op that won't run with experimental compiling enabled."""
# do tensorflow stuff
return tensor
# a fake op that has an xla implementation
def xla_compliant_op(tensor):
"""An op that will run with experimental compiling enabled."""
# do tensorflow stuff
return 2 * tensor
测试未修饰的功能
def maybe_decorated_func(tensor):
compile_on = None
if hasattr(maybe_decorated_func, '_experimental_compile'):
compile_on = maybe_decorated_func._experimental_compile
if compile_on:
return xla_compliant_op(tensor)
else:
return xla_sensitive_op(tensor)
class TestXLASensitiveOp(test_lib.TestCase):
def setUp(self):
self.tensor = tf.constant([1, 2, 3])
def test_func_is_not_decorated(self):
self.assertAllEqual(
maybe_decorated_func(self.tensor),
tf.constant([1, 2, 3])) # <= executes 2nd branch of if/else
if __name__ == "__main__":
test_lib.main()
# [ RUN ] TestXLASensitiveOp.test_func_is_not_decorated
# [ OK ] TestXLASensitiveOp.test_func_is_not_decorated
测试修饰的函数,但未启用实验性编译
@tf.function
def maybe_decorated_func(tensor):
compile_on = None
if hasattr(maybe_decorated_func, '_experimental_compile'):
compile_on = maybe_decorated_func._experimental_compile
if compile_on:
return xla_compliant_op(tensor)
else:
return xla_sensitive_op(tensor)
class TestXLASensitiveOp(test_lib.TestCase):
def setUp(self):
self.tensor = tf.constant([1, 2, 3])
def test_func_is_decorated_but_no_compile(self):
self.assertAllEqual(
maybe_decorated_func(self.tensor),
tf.constant([1, 2, 3])) # <= executes 2nd branch of if/else
if __name__ == "__main__":
test_lib.main()
# [ RUN ] TestXLASensitiveOp.test_func_is_decorated_but_no_compile
# [ OK ] TestXLASensitiveOp.test_func_is_decorated_but_no_compile
在启用实验性编译的情况下测试修饰的函数
@tf.function(experimental_compile=True)
def maybe_decorated_func(tensor):
compile_on = None
if hasattr(maybe_decorated_func, '_experimental_compile'):
compile_on = maybe_decorated_func._experimental_compile
if compile_on:
return xla_compliant_op(tensor)
else:
return xla_sensitive_op(tensor)
class TestXLASensitiveOp(test_lib.TestCase):
def setUp(self):
self.tensor = tf.constant([1, 2, 3])
def test_func_is_decorated_with_compile(self):
self.assertAllEqual(
maybe_decorated_func(self.tensor),
tf.constant([2, 4, 6])) # <= executes 1st branch of if/else
if __name__ == "__main__":
test_lib.main()
# [ RUN ] TestXLASensitiveOp.test_func_is_decorated_with_compile
# [ OK ] TestXLASensitiveOp.test_func_is_decorated_with_compile
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句