Skip to content

Commit e7a97eb

Browse files
committed
Add tests & fixes for time derivatives
1 parent 3fa4237 commit e7a97eb

File tree

3 files changed

+82
-22
lines changed

3 files changed

+82
-22
lines changed

src/analysis/lattice.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,13 @@ struct Incidence
139139
if is_non_incidence_type(type)
140140
throw(DomainError(type, "Invalid type for Incidence"))
141141
end
142-
if !isa(row, IncidenceVector)
142+
if !isa(row, SparseVector)
143143
vec, row = row, _zero_row()
144144
for (i, val) in enumerate(vec)
145145
row[i] = val
146146
end
147+
else
148+
row = convert(IncidenceVector, row)
147149
end
148150
time = row[1]
149151
if in(time, (linear_time_dependent, linear_time_and_state_dependent))

src/analysis/refiner.jl

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ end
7979
function structural_inc_ddt(var_to_diff::DiffGraph, varclassification::Union{Vector{VarEqClassification}, Nothing}, varkinds::Union{Vector{Intrinsics.VarKind}, Nothing}, inc::Union{Incidence, Const})
8080
isa(inc, Const) && return Const(zero(inc.val))
8181
r = _zero_row()
82-
function get_or_make_diff(v_offset::Int)
83-
v = v_offset - 1
82+
function get_or_make_diff(i::Int)
83+
v = i - 1
8484
var_to_diff[v] !== nothing && return var_to_diff[v] + 1
8585
dv = add_vertex!(var_to_diff)
8686
if varclassification !== nothing
@@ -94,21 +94,20 @@ function structural_inc_ddt(var_to_diff::DiffGraph, varclassification::Union{Vec
9494
end
9595
base = isa(inc.typ, Const) ? Const(zero(inc.typ.val)) : inc.typ
9696
indices = rowvals(inc.row)
97-
for (v_offset, coeff) in zip(indices, nonzeros(inc.row))
98-
if v_offset == 1
99-
# t
97+
for (i, coeff) in zip(indices, nonzeros(inc.row))
98+
if i == 1 # time
10099
if isa(coeff, Float64) # known constant coefficient
101-
# Do not set r[v_offset]; d/dt t = 1
100+
# Do not set r[i]; d/dt t = 1
102101
if isa(base, Const)
103102
base = Const(base.val + coeff)
104103
end
105104
elseif coeff.nonlinear
106-
r[v_offset] = nonlinear
105+
r[i] = nonlinear
107106
else
108107
@assert !coeff.time_dependent # should be nonlinear if time-dependent
109108
if coeff.state_dependent # e.g. u₁ * t
110109
# State dependence will not be eliminated because of the chain rule.
111-
r[v_offset] = coeff
110+
r[i] = coeff
112111
else # unknown constant coefficient
113112
if isa(base, Const)
114113
# We are adding an unknown but constant value to the
@@ -121,16 +120,30 @@ function structural_inc_ddt(var_to_diff::DiffGraph, varclassification::Union{Vec
121120
end
122121
if isa(coeff, Float64)
123122
# Linear with a known constant coefficient, just add to the derivative
124-
r[get_or_make_diff(v_offset)] += coeff
123+
r[get_or_make_diff(i)] += coeff
125124
elseif !coeff.state_dependent && !coeff.time_dependent
126125
# Linear with an unknown constant coefficient.
127-
r[get_or_make_diff(v_offset)] = coeff
126+
r[get_or_make_diff(i)] = coeff
128127
elseif coeff.nonlinear
129-
r[v_offset] = nonlinear
130-
r[get_or_make_diff(v_offset)] = nonlinear
128+
r[i] = nonlinear
129+
r[get_or_make_diff(i)] = nonlinear
130+
# derivative may yield constant terms (only if time-dependent,
131+
# but we conservatively assume nonlinear not to be time-independent)
132+
base = widenconst(base)
131133
else # time- or state-dependent linear coefficient
132-
r[v_offset] = coeff
133-
r[get_or_make_diff(v_offset)] = coeff
134+
if !coeff.state_dependent && coeff.time_dependent && length(nonzeros(inc.row)) == 2 && inc.row[1] nonlinear
135+
# `f(∝t, ∝ₜuᵢ)` is of the form `(a + bt)(c + duᵢ)`, we can simplify to `(a + bt)t∂ₜuᵢ + b(c + duᵢ)`,
136+
# yielding `∝ₜ∂uᵢ + bc + ∝uᵢ`. (note that in this case we'll also need to widen `base`, as done further below)
137+
r[i] = linear
138+
r[get_or_make_diff(i)] = linear_time_dependent
139+
else
140+
r[i] = coeff
141+
r[get_or_make_diff(i)] = coeff
142+
end
143+
if coeff.time_dependent
144+
# The product rule with a time factor may yield constant terms.
145+
base = widenconst(base)
146+
end
134147
end
135148
end
136149
return Incidence(base, r)

test/incidence.jl

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ postwalk(f, ex) = walk(ex, x -> postwalk(f, x), f)
2323
x + u₃
2424
end
2525
@infer_incidence u₁ + u₂ true # a second argument of `true` makes it return the IR, not the Incidence
26-
@infer_incidence exp(u₂) + u₁ # will create u₂ as the second continuous variable
26+
@infer_incidence 1/u₂ + u₁ # will create u₂ as the second continuous variable
2727
2828
Return the `Incidence` object after inferring the structure of the provided code,
2929
substituting any variables starting by 'u'. Variables are created as `continuous()`
@@ -174,6 +174,11 @@ dependencies(row) = sort(rowvals(row) .=> nonzeros(row), by = first)
174174
@test dependencies(incidence.row) == [2 => 1]
175175
@test incidence == incidence"u₁"
176176

177+
incidence = @infer_incidence 3.0 *ᵢ u₁
178+
@test incidence.typ === Const(0.0)
179+
@test dependencies(incidence.row) == [2 => 3.0]
180+
@test incidence == incidence"3.0u₁"
181+
177182
incidence = @infer_incidence u₁ +ᵢ u₂
178183
@test incidence.typ === Const(0.0)
179184
@test dependencies(incidence.row) == [2 => 1, 3 => 1]
@@ -274,26 +279,66 @@ dependencies(row) = sort(rowvals(row) .=> nonzeros(row), by = first)
274279
@test dependencies(incidence.row) == [2 => nonlinear, 3 => linear_state_dependent, 4 => linear_state_dependent]
275280
@test incidence == incidence"f(u₁, ∝ₛu₂, ∝ₛu₃)"
276281

277-
incidence = @infer_incidence exp(u₁)
282+
incidence = @infer_incidence 1/u₁
278283
@test dependencies(incidence.row) == [2 => nonlinear]
279284
@test incidence == incidence"a + f(u₁)"
280285

281-
incidence = @infer_incidence t * exp(u₁)
286+
incidence = @infer_incidence t * (1/u₁)
282287
@test dependencies(incidence.row) == [1 => linear_state_dependent, 2 => nonlinear]
283288
@test incidence == incidence"a + f(∝ₛt, u₁)"
284289

285-
incidence = @infer_incidence u₁ * exp(t)
290+
incidence = @infer_incidence u₁ * (1/t)
286291
@test dependencies(incidence.row) == [1 => nonlinear, 2 => linear_time_dependent]
287292
@test incidence == incidence"a + f(t, ∝ₜu₁)"
288293

289-
incidence = @infer_incidence u₁ * exp(t + u₂)
294+
incidence = @infer_incidence u₁ * (1/(t + u₂))
290295
@test dependencies(incidence.row) == [1 => nonlinear, 2 => linear_time_and_state_dependent, 3 => nonlinear]
291296
@test incidence == incidence"a + f(t, ∝ₜₛu₁, u₂)"
292297

293-
incidence = @infer_incidence atan(u₁, u₂)
298+
incidence = @infer_incidence 1/(u₁ * u₂)
294299
@test dependencies(incidence.row) == [2 => nonlinear, 3 => nonlinear]
295300
@test incidence == incidence"a + f(u₁, u₂)"
296301
end
297-
end;
302+
303+
@testset "Time derivatives" begin
304+
incidence = @infer_incidence ddt(3.0 *ᵢ t)
305+
@test incidence == incidence"3.0"
306+
307+
incidence = @infer_incidence ddt(3.0 *ᵢ t +5.0)
308+
@test incidence == incidence"3.0"
309+
310+
incidence = @infer_incidence ddt(3.0 * t)
311+
@test incidence == incidence"a"
312+
313+
incidence = @infer_incidence ddt(u₁)
314+
@test incidence == incidence"u₂"
315+
316+
incidence = @infer_incidence ddt(1.0 +ᵢ u₁)
317+
@test incidence == incidence"u₂"
318+
319+
incidence = @infer_incidence u₁ +ddt(u₁)
320+
@test incidence == incidence"u₁ + u₂"
321+
322+
incidence = @infer_incidence ddt(u₁ *ᵢ u₂)
323+
@test incidence == incidence"f(∝ₛu₁, ∝ₛu₂, ∝ₛu₃, ∝ₛu₄)"
324+
325+
incidence = @infer_incidence ddt(u₁ *ᵢ t)
326+
@test incidence == incidence"a + ∝u₁ + f(∝ₛt, ∝ₜu₂)"
327+
328+
incidence = @infer_incidence ddt(u₁ *ᵢ u₁)
329+
@test incidence == incidence"a + f(u₁, u₂)"
330+
# Note that the constant term may be removed if we
331+
# model nonlinear time-independent incidences.
332+
333+
incidence = @infer_incidence ddt(1/u₁)
334+
@test incidence == incidence"a + f(u₁, u₂)"
335+
336+
incidence = @infer_incidence ddt((2.0 +ᵢ u₁) *ᵢ (3.0 +ᵢ u₂))
337+
@test incidence == incidence"f(∝ₛu₁, ∝ₛu₂, ∝ₛu₃, ∝ₛu₄)"
338+
339+
incidence = @infer_incidence ddt((2.0 +ᵢ u₁) *ᵢ (3.0 +ᵢ t))
340+
@test incidence == incidence"a + ∝u₁ + f(∝ₛt, ∝ₜu₂)"
341+
end
342+
end
298343

299344
end

0 commit comments

Comments
 (0)