Comments (6)
Here's the LLVM code I'm seeing:
julia> @code_llvm f(π/4)
define double @julia_f_22987(double) {
top:
%1 = call double inttoptr (i64 13147291680 to double (double)*)(double %0)
%2 = call double inttoptr (i64 13147331552 to double (double)*)(double %0)
%3 = call double inttoptr (i64 13147347792 to double (double)*)(double %0)
%4 = fcmp ord double %2, 0.000000e+00
%5 = fcmp uno double %0, 0.000000e+00
%6 = or i1 %4, %5
br i1 %6, label %pass, label %fail
fail: ; preds = %top
%7 = load %jl_value_t** @jl_domain_exception, align 8
call void @jl_throw_with_superfluous_argument(%jl_value_t* %7, i32 1)
unreachable
pass: ; preds = %top
%8 = call double @pow(double %2, double 3.000000e+00)
%9 = fcmp ord double %3, 0.000000e+00
%10 = or i1 %9, %5
br i1 %10, label %pass4, label %fail3
fail3: ; preds = %pass
%11 = load %jl_value_t** @jl_domain_exception, align 8
call void @jl_throw_with_superfluous_argument(%jl_value_t* %11, i32 1)
unreachable
pass4: ; preds = %pass
%12 = call double @pow(double %3, double 3.000000e+00)
%13 = fadd double %8, %12
%14 = fdiv double %1, %13
ret double %14
}
julia> @code_llvm f(π/4 + ɛ)
define void @julia_f_22763(%Dual* sret, %Dual*) {
top:
%2 = bitcast %Dual* %1 to double*
%3 = load double* %2, align 8
%4 = call double inttoptr (i64 13147291680 to double (double)*)(double %3)
%5 = load double* %2, align 8
%6 = call double inttoptr (i64 13147291680 to double (double)*)(double %5)
%7 = load double* %2, align 8
%8 = call double inttoptr (i64 13147331552 to double (double)*)(double %7)
%9 = load double* %2, align 8
%10 = call double inttoptr (i64 13147347792 to double (double)*)(double %9)
%11 = fcmp ord double %8, 0.000000e+00
%12 = fcmp uno double %7, 0.000000e+00
%13 = or i1 %11, %12
br i1 %13, label %pass, label %fail
fail: ; preds = %top
%14 = load %jl_value_t** @jl_domain_exception, align 8
call void @jl_throw_with_superfluous_argument(%jl_value_t* %14, i32 1)
unreachable
pass: ; preds = %top
%15 = fcmp ord double %10, 0.000000e+00
%16 = fcmp uno double %9, 0.000000e+00
%17 = or i1 %15, %16
br i1 %17, label %pass2, label %fail1
fail1: ; preds = %pass
%18 = load %jl_value_t** @jl_domain_exception, align 8
call void @jl_throw_with_superfluous_argument(%jl_value_t* %18, i32 1)
unreachable
pass2: ; preds = %pass
%19 = bitcast %Dual* %1 to double*
%20 = alloca %Dual, align 8
%21 = alloca %Dual, align 8
%22 = getelementptr inbounds %Dual* %1, i64 0, i32 1
%23 = load double* %22, align 8
%24 = fmul double %23, -1.000000e+00
%25 = insertvalue %Dual undef, double %8, 0
%26 = fmul double %10, %24
%27 = insertvalue %Dual %25, double %26, 1
store %Dual %27, %Dual* %21, align 8
call void @julia_power_by_squaring_22764(%Dual* sret %20, %Dual* %21, i64 3)
%28 = load %Dual* %20, align 8
%29 = load double* %19, align 8
%30 = call double inttoptr (i64 13147347792 to double (double)*)(double %29)
%31 = load double* %19, align 8
%32 = call double inttoptr (i64 13147331552 to double (double)*)(double %31)
%33 = fcmp ord double %30, 0.000000e+00
%34 = fcmp uno double %29, 0.000000e+00
%35 = or i1 %33, %34
br i1 %35, label %pass4, label %fail3
fail3: ; preds = %pass2
%36 = load %jl_value_t** @jl_domain_exception, align 8
call void @jl_throw_with_superfluous_argument(%jl_value_t* %36, i32 1)
unreachable
pass4: ; preds = %pass2
%37 = fcmp ord double %32, 0.000000e+00
%38 = fcmp uno double %31, 0.000000e+00
%39 = or i1 %37, %38
br i1 %39, label %pass6, label %fail5
fail5: ; preds = %pass4
%40 = load %jl_value_t** @jl_domain_exception, align 8
call void @jl_throw_with_superfluous_argument(%jl_value_t* %40, i32 1)
unreachable
pass6: ; preds = %pass4
%41 = alloca %Dual, align 8
%42 = alloca %Dual, align 8
%43 = extractvalue %Dual %28, 0
%44 = extractvalue %Dual %28, 1
%sunkaddr = ptrtoint %Dual* %1 to i64
%sunkaddr13 = add i64 %sunkaddr, 8
%sunkaddr14 = inttoptr i64 %sunkaddr13 to double*
%45 = load double* %sunkaddr14, align 8
%46 = insertvalue %Dual undef, double %30, 0
%47 = fmul double %32, %45
%48 = insertvalue %Dual %46, double %47, 1
store %Dual %48, %Dual* %42, align 8
call void @julia_power_by_squaring_22764(%Dual* sret %41, %Dual* %42, i64 3)
%49 = load %Dual* %41, align 8
%50 = extractvalue %Dual %49, 0
%51 = extractvalue %Dual %49, 1
%52 = load double* %sunkaddr14, align 8
%53 = fmul double %6, %52
%54 = fadd double %43, %50
%55 = fadd double %44, %51
%56 = fdiv double %4, %54
%57 = insertvalue %Dual undef, double %56, 0
%58 = fmul double %53, %54
%59 = fmul double %4, %55
%60 = fsub double %58, %59
%61 = fmul double %54, %54
%62 = fdiv double %60, %61
%63 = insertvalue %Dual %57, double %62, 1
store %Dual %63, %Dual* %0, align 8
ret void
}
It seems that the real version is calling the pow intrinsic while the Dual version is calling Julia's power_by_squaring. They're getting the same answer, which is reassuring, but it's a bit crazy that doing the power-by-squaring thing could be faster than calling pow.
from forwarddiff.jl.
Also, installing all of that code was insanely easy. Great work packaging, everyone!
Ah, that's the difference of course: I haven't implemented ^ for Dual numbers, so it falls back to power_by_squaring.
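To make the fallback concrete, here is a minimal sketch. MiniDual and pow_special are hypothetical names made up for illustration; this is not the Dual type or any method from the package discussed in this thread. It assumes only that a Number subtype with * defined picks up Base's generic integer power (the power_by_squaring route seen in the LLVM output above), and compares that against a hand-written power rule.

```julia
# Hypothetical minimal dual number, for illustration only.
struct MiniDual <: Number
    re::Float64   # value part
    du::Float64   # derivative (ɛ) part
end

# Product rule: (a + b ɛ)(c + d ɛ) = ac + (ad + bc) ɛ
Base.:*(x::MiniDual, y::MiniDual) =
    MiniDual(x.re * y.re, x.re * y.du + x.du * y.re)

# Generic route: no ^ method is defined for MiniDual, so an integer
# exponent goes through Base's fallback, which only needs *.
cube_generic(x::MiniDual) = x^3

# Specialized route (assumed design, not the package's method):
# d/dx x^p = p * x^(p - 1), evaluated with floating-point powers.
pow_special(x::MiniDual, p::Real) =
    MiniDual(x.re^p, p * x.re^(p - 1) * x.du)

x = MiniDual(0.5, 1.0)
a = cube_generic(x)
b = pow_special(x, 3.0)

# Both routes should agree on the value and the derivative.
@assert a.re ≈ b.re
@assert a.du ≈ b.du
```

Both routes give the same value and derivative here; which one is faster depends on the Julia version and on how pow is implemented on the platform, so it is worth measuring rather than assuming.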
Writing out the powers as explicit multiplication gives:
f(x) = exp(x) / (cos(x)*cos(x)*cos(x) + sin(x)*sin(x)*sin(x))
julia> @benchmark f(π/4 + ɛ)
================ Benchmark Results ========================
Time per evaluation: 90.01 ns [88.58 ns, 91.44 ns]
julia> @benchmark f(π/4 + im)
================ Benchmark Results ========================
Time per evaluation: 604.65 ns [602.47 ns, 606.82 ns]
julia> @benchmark f(π/4)
================ Benchmark Results ========================
Time per evaluation: 81.51 ns [81.30 ns, 81.73 ns]
Here reals are a bit faster than duals, and complex numbers are much(!) slower.
Okay, I think this is the reason:
julia> @benchmark sin(π/4)^3
================ Benchmark Results ========================
Time per evaluation: 87.58 ns [86.45 ns, 88.72 ns]
julia> @benchmark sin(π/4)^3.0
================ Benchmark Results ========================
Time per evaluation: 22.57 ns [22.27 ns, 22.87 ns]
Changing the exponents in f to floats, like:
f(x) = exp(x) / (cos(x)^3.0 + sin(x)^3.0)
makes the float version ~4 times faster than dual.
You heard it here first, folks: start writing your exponents as floats!
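For anyone who wants to re-check this on their own setup, the sketch below compares the three spellings of the cube used in this thread. The helper names (cube_int, cube_float, cube_mul, time_it) are made up for illustration, and the timing loop uses only Base rather than the benchmarking package used above, so it is much cruder. No timings are quoted because they vary with Julia version and hardware.

```julia
# Three spellings of the same cube; the names are illustrative only.
cube_int(x)   = x^3        # integer exponent
cube_float(x) = x^3.0      # floating-point exponent
cube_mul(x)   = x * x * x  # explicit multiplication

x = sin(π / 4)

# All three should agree up to floating-point rounding.
@assert cube_int(x) ≈ cube_float(x) ≈ cube_mul(x)

# Crude average time per call, in seconds. The running sum is returned
# so the calls cannot be dropped as unused work.
function time_it(f, x, n)
    s = 0.0
    t = @elapsed for _ in 1:n
        s += f(x)
    end
    return t / n, s
end

for f in (cube_int, cube_float, cube_mul)
    time_it(f, x, 10)                      # warm-up run to exclude compilation
    per_call, _ = time_it(f, x, 1_000_000)
    println(f, ": ", per_call * 1e9, " ns per call")
end
```

Treat the printed numbers as a rough guide only; a dedicated benchmarking tool gives more trustworthy measurements.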
Closing this, since it turned out to be unrelated to Dual numbers or automatic differentiation. Sorry for the noise.